From b794cf46984cc9b94f8b06b93afc1006390d8b5a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 31 Mar 2024 20:34:57 -0400 Subject: [PATCH 001/495] Set development branch to 0.12 (#1375) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a5713ad20d..dbc3b7478d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.11.17" +version = "0.12.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From ed15bb48d4b69304222cad2781af57f5fa2a4cea Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 31 Mar 2024 22:34:18 -0400 Subject: [PATCH 002/495] Mark randn! as inactive (#1378) --- src/internal_rules.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 9bcce5925c..8afcaa0a20 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -75,6 +75,9 @@ end function EnzymeRules.inactive(::typeof(Random.randn), args...) return nothing end +function EnzymeRules.inactive(::typeof(Random.randn!), args...) + return nothing +end function EnzymeRules.inactive(::typeof(Random.default_rng), args...) 
return nothing end From 6ac1212928113fa3dbfef97c9e841903042eb31e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 31 Mar 2024 23:35:56 -0400 Subject: [PATCH 003/495] [EnzymeTestUtils] Mark 1.8 batch test as failing (#1379) --- lib/EnzymeTestUtils/test/test_forward.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 8768d5324e..a2ab010042 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -85,7 +85,9 @@ end elseif TT <: NamedTuple x = (a=randn(T), b=randn(T)) else # TT <: TestStruct - VERSION ≤ v"1.8" && (@test_skip false; continue) + if VERSION <= v"1.8" && Tx == BatchDuplicated + continue + end x = TestStruct(randn(T, 5), randn(T)) end atol = rtol = sqrt(eps(real(T))) From edfb21fdc85949c78fcad095527d9aefc0c40c15 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 31 Mar 2024 23:55:47 -0400 Subject: [PATCH 004/495] Fix undefined memory in faq (#1376) --- docs/src/faq.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/faq.md b/docs/src/faq.md index c8315464ac..72fa1f97d9 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -271,7 +271,7 @@ Enzyme.autodiff(Reverse, f, Active(1.2), Const(Vector{Float64}(undef, 1)), Const Passing in a dupliacted (e.g. differentiable) variable for `tmp` now leads to the correct answer. 
```jldoctest storage -Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), Vector{Float64}(undef, 1)), Const(1), Const(5)) # Correct (returns 10.367999999999999 == 1.2^4 * 5) +Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), zeros(1)), Const(1), Const(5)) # Correct (returns 10.367999999999999 == 1.2^4 * 5) # output @@ -539,4 +539,4 @@ For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] + i [u_y + i v_y] \right) = \fr 3.1 + 2.7im ``` -Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space). \ No newline at end of file +Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space). 
From 55b75677685f3356aa5a818567edb443d3ae4e80 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 5 Apr 2024 13:55:37 -0700 Subject: [PATCH 005/495] Restore 1.9+ setfield (#1383) --- src/rules/llvmrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index f066910450..4c06290ce9 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -60,7 +60,7 @@ function jlcall_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if in(name, ("ijl_f_getfield", "jl_f_getfield")) return common_jl_getfield_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if in(name, ("ijl_s_getfield", "jl_s_getfield")) + if in(name, ("ijl_f_setfield", "jl_f_setfield")) return common_setfield_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end if in(name, ("ijl_f__apply_iterate", "jl_f__apply_iterate")) From 724b9bc31c316744142b90e3b9c603a9d60270f5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 6 Apr 2024 04:56:41 -0700 Subject: [PATCH 006/495] Update Project.toml (#1384) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dbc3b7478d..1032750266 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ EnzymeSpecialFunctionsExt = "SpecialFunctions" [compat] CEnum = "0.4, 0.5" EnzymeCore = "0.7" -Enzyme_jll = "0.0.103" +Enzyme_jll = "0.0.104" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1" ObjectFile = "0.4" From 1bf16f8217f2f0e516666f5dff2deb27a653302d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 6 Apr 2024 08:55:50 -0700 Subject: [PATCH 007/495] add complex sqrt (#1324) * add complex sqrt * fixes --- src/compiler.jl | 208 ++++++++++++++++++++++++------------ src/compiler/interpreter.jl | 2 +- test/runtests.jl | 1 + 3 files changed, 139 insertions(+), 72 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 03a413c880..7fdcf86dbd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -45,17 
+45,6 @@ end import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error -safe_println(head, tail) = ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring...), "%s%s\n",head, tail) -macro safe_show(exs...) - blk = Expr(:block) - for ex in exs - push!(blk.args, :($safe_println($(sprint(Base.show_unquoted, ex)*" = "), - repr(begin local value = $(esc(ex)) end)))) - end - isempty(exs) || push!(blk.args, :value) - return blk -end - if LLVM.has_orc_v1() include("compiler/orcv1.jl") else @@ -70,6 +59,7 @@ include("compiler/utils.jl") const cmplx_known_ops = Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( typeof(Base.inv) => (:cmplx_inv, 1, nothing), + typeof(Base.sqrt) => (:cmplx_sqrt, 1, nothing), ) const known_ops = Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( @@ -4082,10 +4072,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end for e in toErase if !isempty(collect(uses(e))) - @safe_show mod - @safe_show entry_f - @safe_show e - throw(AssertionError("Use after deletion")) + msg = sprint() do io + println(io, string(mod)) + println(io, string(entry_f)) + println(io, string(e)) + println(io, "Use after deletion") + end + throw(AssertionError(msg)) end LLVM.API.LLVMInstructionEraseFromParent(e) end @@ -4144,6 +4137,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function @assert eltype(ty) == value_type(wrapparm) store!(builder, wrapparm, ptr) push!(wrapper_args, ptr) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen)))) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) else 
push!(wrapper_args, wrapparm) for attr in collect(parameter_attributes(entry_f, arg.codegen.i)) @@ -4206,16 +4202,26 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function position!(builder, def) ret!(builder, extract_value!(builder, res, 0)) + + push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) + push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) + push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) end elseif sret if sretPtr === nothing ret!(builder) else + push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) + push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) + push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) ret!(builder, load!(builder, RT, sretPtr)) end elseif LLVM.return_type(entry_ft) == LLVM.VoidType() ret!(builder) else + push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) + push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) + push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) ret!(builder, res) end dispose(builder) @@ -4231,14 +4237,52 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function attributes = function_attributes(wrapper_f) push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))))) push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt))))) + + for prev in collect(function_attributes(entry_f)) + if 
kind(prev) == kind(StringAttribute("enzyme_ta_norecur")) + push!(attributes, prev) + end + if kind(prev) == kind(StringAttribute("enzyme_parmremove")) + push!(attributes, prev) + end + if kind(prev) == kind(StringAttribute("enzyme_math")) + push!(attributes, prev) + end + if kind(prev) == kind(StringAttribute("enzyme_shouldrecompute")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("readonly")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("readnone")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("argmemonly")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("inaccessiblememonly")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("speculatable")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("nofree")) + push!(attributes, prev) + end + if kind(prev) == kind(StringAttribute("enzyme_inactive")) + push!(attributes, prev) + end + end if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 - @safe_show mod - @safe_show LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction) - @safe_show wrapper_f - @safe_show parmsRemoved, retRemoved, prargs - flush(stdout) - throw(LLVM.LLVMException("broken function")) + msg = sprint() do io + println(io, string(mod)) + println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println(io, string(wrapper_f)) + println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs) + println(io, "Broken function") + end + throw(LLVM.LLVMException(msg)) end ModulePassManager() do pm @@ -4333,19 +4377,17 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 - @safe_show mod - @safe_show LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction) - @safe_show wrapper_f - flush(stdout) - 
throw(LLVM.LLVMException("broken function")) + msg = sprint() do io + println(io, string(mod)) + println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println(io, string(wrapper_f)) + println(io, "Broken function") + end + throw(LLVM.LLVMException(msg)) end return wrapper_f, returnRoots, boxedArgs, loweredArgs end -function adim(::Array{T, N}) where {T, N} - return N -end - function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) @@ -4628,7 +4670,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; name = meth.name jlmod = meth.module - function handleCustom(name, attrs=[], setlink=true, noinl=true) + function handleCustom(llvmfn, name, attrs=[], setlink=true, noinl=true) attributes = function_attributes(llvmfn) custom[k_name] = linkage(llvmfn) if setlink @@ -4647,7 +4689,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; julia_activity_rule(llvmfn) if has_custom_rule - handleCustom("enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")]) + handleCustom(llvmfn, "enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")]) continue end @@ -4655,7 +4697,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) - handleCustom("jl_inactive_inout", [StringAttribute("enzyme_inactive"), + handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), EnumAttribute("readnone", 0), EnumAttribute("speculatable", 0), StringAttribute("enzyme_shouldrecompute") @@ -4663,7 +4705,7 @@ function GPUCompiler.codegen(output::Symbol, 
job::CompilerJob{<:EnzymeTarget}; continue end if func == typeof(Base.to_tuple_type) - handleCustom("jl_to_tuple_type", + handleCustom(llvmfn, "jl_to_tuple_type", [EnumAttribute("readonly", 0), EnumAttribute("inaccessiblememonly", 0), EnumAttribute("speculatable", 0), @@ -4674,7 +4716,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) name = (func == typeof(Base.Threads.threadid)) ? "jl_threadid" : "jl_nthreads" - handleCustom(name, + handleCustom(llvmfn, name, [EnumAttribute("readonly", 0), EnumAttribute("inaccessiblememonly", 0), EnumAttribute("speculatable", 0), @@ -4689,15 +4731,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # fn, but it doesn't presently so for now we will ensure this by hand if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] - handleCustom("enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly")]) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly")]) continue end if EnzymeRules.is_inactive_from_sig(mi.specTypes; world, method_table, caller) - handleCustom("enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")]) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")]) continue end if EnzymeRules.is_inactive_noinl_from_sig(mi.specTypes; world, method_table, caller) - handleCustom("enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")], false, false) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")], false, false) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) @@ -4709,54 +4751,78 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end if func == typeof(Base.enq_work) && length(sparam_vals) 
== 1 && first(sparam_vals) <: Task - handleCustom("jl_enq_work") + handleCustom(llvmfn, "jl_enq_work") continue end if func == typeof(Base.wait) || func == typeof(Base._wait) if length(sparam_vals) == 1 && first(sparam_vals) <: Task - handleCustom("jl_wait") + handleCustom(llvmfn, "jl_wait") end continue end if func == typeof(Base.Threads.threading_run) if length(sparam_vals) == 1 || length(sparam_vals) == 2 - handleCustom("jl_threadsfor") + handleCustom(llvmfn, "jl_threadsfor") end continue end - name = nothing - arity = nothing - toinject = nothing - Tys = nothing + @inline function find_math_method() + if func ∈ keys(known_ops) + name, arity, toinject = known_ops[func] + Tys = (Float32, Float64) + + if length(sparam_vals) == arity + T = first(sparam_vals) + legal = T ∈ Tys + + if legal + if name == :ldexp + if !(sparam_vals[2] <: Integer) + legal = false + end + elseif name == :pow + if sparam_vals[2] <: Integer + name = :powi + elseif sparam_vals[2] != T + legal = false + end + elseif name == :jl_rem2pi + else + if !all(==(T), sparam_vals) + legal = false + end + end + end + if legal + return name, toinject, T + end + end + end - if func ∈ keys(known_ops) - name, arity, toinject = known_ops[func] - Tys = (Float32, Float64) - elseif func ∈ keys(cmplx_known_ops) - name, arity, toinject = cmplx_known_ops[func] - Tys = (Complex{Float32}, Complex{Float64}) - else - continue - end + if func ∈ keys(cmplx_known_ops) + name, arity, toinject = cmplx_known_ops[func] + Tys = (Complex{Float32}, Complex{Float64}) + if length(sparam_vals) == arity + T = first(sparam_vals) + legal = T ∈ Tys - length(sparam_vals) == arity || continue - T = first(sparam_vals) - isfloat = T ∈ Tys - if !isfloat - continue + if legal + if !all(==(T), sparam_vals) + legal = false + end + end + if legal + return name, toinject, T + end + end + end + return nothing, nothing, nothing end - if name == :ldexp - sparam_vals[2] <: Integer || continue - elseif name == :pow - if sparam_vals[2] <: Integer - 
name = :powi - elseif sparam_vals[2] != T - continue - end - elseif name == :jl_rem2pi - else - all(==(T), sparam_vals) || continue + + name, toinject, T = find_math_method() + if name === nothing + continue end if toinject !== nothing @@ -4778,7 +4844,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; name = string(name) name = T == Float32 ? name*"f" : name - handleCustom(name, [EnumAttribute("readnone", 0), + handleCustom(llvmfn, name, [EnumAttribute("readnone", 0), StringAttribute("enzyme_shouldrecompute")]) end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index a2900b3356..5885679be5 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -94,7 +94,7 @@ function is_primitive_func(@nospecialize(TT)) end end - if ft == typeof(Base.inv) + if ft == typeof(Base.inv) || ft == typeof(Base.sqrt) if TT <: Tuple{ft, Complex{Float32}} || TT <: Tuple{ft, Complex{Float64}} return true end diff --git a/test/runtests.jl b/test/runtests.jl index 99aa3b208f..f6100bab81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -277,6 +277,7 @@ make3() = (1.0, 2.0, 3.0) test_scalar(x->rem(x, 1), 0.7) test_scalar(x->rem2pi(x,RoundDown), 0.7) test_scalar(x->fma(x,x+1,x/3), 2.3) + test_scalar(sqrt, 1.7+2.1im) @test autodiff(Forward, sincos, Duplicated(1.0, 1.0))[1][1] ≈ cos(1.0) From 9d6b969b2c36f47d8acf31254498aa4a81edae31 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 15 Apr 2024 15:36:03 -0400 Subject: [PATCH 008/495] Try narrowing inactive rules for `rand` (#1388) * Restrict scope of inactive markers on Randon.rand * Simplify * Keep the randn rules * Add test * Rm newline * Fix typo * Add Random import for test --- src/internal_rules.jl | 4 ++-- test/internal_rules.jl | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 8afcaa0a20..b6dc8c75d6 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ 
-66,10 +66,10 @@ end function EnzymeRules.inactive(::typeof(Core.kwfunc), args...) return nothing end -function EnzymeRules.inactive(::typeof(Random.rand), args...) +function EnzymeRules.inactive(::typeof(Random.rand), ::Random.AbstractRNG, ::Random.Sampler) return nothing end -function EnzymeRules.inactive(::typeof(Random.rand!), args...) +function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, ::Random.Sampler, ::AbstractArray) return nothing end function EnzymeRules.inactive(::typeof(Random.randn), args...) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index f9b2aca957..e325189dc1 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -7,6 +7,7 @@ using FiniteDifferences using LinearAlgebra using SparseArrays using Test +import Random struct TPair a::Float64 @@ -432,4 +433,18 @@ end end end end + +@testset "rand and randn rules" begin + # Distributed as x + unit normal + uniform + struct MyDistribution + x::Float64 + end + + Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn() + rand() + Random.rand(d::MyDistribution) = rand(Random.default_rng(), d) + + # Outer rand should be differentiated through, and inner rand and randn should be ignored. + @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) +end + end # InternalRules From 43340b363fd01cc0e6c9af5c800fc043b752b72b Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sat, 20 Apr 2024 20:21:00 +0100 Subject: [PATCH 009/495] Enzyme rules for `partialsort!` (#1373) * partialsort! Enzyme rules * partialsort! 
rule tests * version bound on test * fix test version bound --- src/internal_rules.jl | 105 +++++++++++++++++++++++++++++++++++++++++ test/internal_rules.jl | 24 ++++++++++ test/runtests.jl | 5 +- 3 files changed, 131 insertions(+), 3 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index b6dc8c75d6..ea33959b23 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -636,6 +636,111 @@ function EnzymeRules.reverse( return (nothing,) end +function EnzymeRules.forward( + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} + kv = k.val + inds = collect(eachindex(xs.val)) + partialsortperm!(inds, xs.val, kv; kwargs...) + xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if RT <: Const + return kv isa Integer ? xs.val[kv] : view(xs.val, kv) + elseif RT <: DuplicatedNoNeed + return kv isa Integer ? xs.dval[kv] : view(xs.dval, kv) + else + if kv isa Integer + return Duplicated(xs.val[kv], xs.dval[kv]) + else + return Duplicated(view(xs.val, kv), view(xs.dval, kv)) + end + end +end + +function EnzymeRules.forward( + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}, N} + kv = k.val + inds = collect(eachindex(xs.val)) + partialsortperm!(inds, xs.val, kv; kwargs...) + xs.val .= xs.val[inds] + for i in 1:N + xs.dval[i] .= xs.dval[i][inds] + end + if RT <: Const + return kv isa Integer ? 
xs.val[kv] : view(xs.val, kv) + elseif RT <: BatchDuplicatedNoNeed + if kv isa Integer + return ntuple(i -> xs.dval[i][kv], N) + else + return ntuple(i -> view(xs.dval[i], kv), N) + end + else + if kv isa Integer + return BatchDuplicated(xs.val[kv], ntuple(i -> xs.dval[i][kv], N)) + else + return BatchDuplicated(view(xs.val, kv), ntuple(i -> view(xs.dval[i], kv), N)) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} + kv = k.val + inds = collect(eachindex(xs.val)) + partialsortperm!(inds, xs.val, kv; kwargs...) + xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if EnzymeRules.needs_primal(config) + primal = kv isa Integer ? xs.val[kv] : view(xs.val, kv) + else + primal = nothing + end + if RT <: Const || RT <: Active + shadow = nothing + else + shadow = kv isa Integer ? xs.dval[kv] : view(xs.dval, kv) + end + return EnzymeRules.AugmentedReturn(primal, shadow, inds) +end + +function EnzymeRules.reverse( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, + tape, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs..., + ) where {T <: AbstractArray{<:AbstractFloat}} + inds = tape + kv = k.val + if dret isa Active + if kv isa Integer + xs.dval[kv] += dret.val + else + xs.dval[kv] .+= dret.val + end + end + back_inds = sortperm(inds) + xs.dval .= xs.dval[back_inds] + return (nothing, nothing) +end + function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) fact = cholesky(A.val; kwargs...) 
if RT <: Const diff --git a/test/internal_rules.jl b/test/internal_rules.jl index e325189dc1..b076a51b3e 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -45,6 +45,30 @@ end @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 + function f3(x) + a = [2.0, 2.5, x, 1.0] + return partialsort(a, 2) + end + + @test autodiff(Forward, f3, Duplicated(1.5, 1.0))[1] == 1.0 + @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Reverse, f3, Active(1.5))[1][1] == 1.0 + @test autodiff(Reverse, f3, Active(2.5))[1][1] == 0.0 + + function f4(x) + a = [2.0, 2.5, x, x / 2] + y = partialsort(a, 1:2) + return sum(y) + end + + @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 + @static if VERSION < v"1.7-" || VERSION >= v"1.8-" + @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) + end + @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 + @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 + @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 + dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) diff --git a/test/runtests.jl b/test/runtests.jl index f6100bab81..0e9c455e99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2883,11 +2883,10 @@ end @test autodiff(Forward, f6, Duplicated(4.0, 1.0))[1] ≈ 5/3 f7(x) = median([2.0, 1.0, x]) - # Fails on Julia 1.9 due to #880 - #=@test autodiff(Reverse, f7, Active, Active(1.5))[1][1] == 1 + @test autodiff(Reverse, f7, Active, Active(1.5))[1][1] == 1 @test autodiff(Forward, f7, Duplicated(1.5, 1.0))[1] == 1 @test autodiff(Reverse, f7, Active, Active(2.5))[1][1] == 0 - @test autodiff(Forward, f7, Duplicated(2.5, 1.0))[1] == 0=# + @test autodiff(Forward, f7, Duplicated(2.5, 1.0))[1] == 0 f8(x) = 
middle([2.0, x, 1.0]) @test autodiff(Reverse, f8, Active, Active(2.5))[1][1] == 0.5 From 4230cc09fb4218137cf86c732f6e4b203f1345c5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 20 Apr 2024 15:37:30 -0400 Subject: [PATCH 010/495] Remove vararg tuple warning (#1395) --- src/compiler.jl | 2 -- src/typetree.jl | 6 ++++++ src/utils.jl | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7fdcf86dbd..6fb1b7cc7a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -461,8 +461,6 @@ end end end - @inline is_concrete_tuple(x::T2) where T2 = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) - @assert !Base.isabstracttype(T) if !(Base.isconcretetype(T) || is_concrete_tuple(T) || T isa UnionAll) throw(AssertionError("Type $T is not concrete type or concrete tuple")) diff --git a/src/typetree.jl b/src/typetree.jl index 50cd399cc0..79ca41cd81 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -191,6 +191,12 @@ function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) return TypeTree() end + @static if VERSION >= v"1.7.0" + if is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) + return TypeTree() + end + end + try fieldcount(T) catch diff --git a/src/utils.jl b/src/utils.jl index 5a13a1673f..28dd0b4d65 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,9 @@ unsafe_to_pointer(ptr) = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), ptr) export unsafe_to_pointer +@inline is_concrete_tuple(x::T2) where T2 = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) +export is_concrete_tuple + const Tracked = 10 const Derived = 11 export Tracked, Derived From 801c7434e9efa91e68a754bf6e5ed9c90b24dc70 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 20 Apr 2024 15:37:50 -0400 Subject: [PATCH 011/495] Identify and attempt fix for gc bug (#1386) * Identify and attempt fix for gc bug * audit unsafe_to_pointer * fixup! 
audit unsafe_to_pointer * Fixup * Add sparse eval test --------- Co-authored-by: Valentin Churavy --- src/compiler.jl | 2 -- src/rules/llvmrules.jl | 46 +++++++++++++++++++++++++----------------- src/utils.jl | 30 ++++++++++++++++++++++++--- test/runtests.jl | 20 ++++++++++++++++++ 4 files changed, 74 insertions(+), 24 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6fb1b7cc7a..1cc7f14de8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5257,8 +5257,6 @@ end function add_one_in_place(x) ty = typeof(x) - # ptr = Base.pointer_from_objref(x) - ptr = unsafe_to_pointer(x) if ty <: Base.RefValue || ty == Base.RefValue{Float64} x[] = recursive_add(x[], default_adjoint(eltype(ty))) else diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 4c06290ce9..1ce53d09f2 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -223,7 +223,6 @@ function arraycopy_fwd(B, orig, gutils, normalR, shadowR) elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) len = get_array_len(B, shadowin) length = LLVM.mul!(B, len, elSize) - isVolatile = LLVM.ConstantInt(LLVM.IntType(1), 0) GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type" LLVM.memset!(B, get_array_data(B, shadowres), LLVM.ConstantInt(i8, 0, false), length, algn) end @@ -242,7 +241,6 @@ function arraycopy_fwd(B, orig, gutils, normalR, shadowR) elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) len = get_array_len(B, shadowin) length = LLVM.mul!(B, len, elSize) - isVolatile = LLVM.ConstantInt(LLVM.IntType(1), 0) GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type" LLVM.memset!(B, get_array_data(callv), LLVM.ConstantInt(i8, 0, false), length, algn) end @@ -894,25 +892,35 @@ function jl_array_del_end_rev(B, orig, gutils, tape) offset = new_from_original(gutils, origops[2]) offset = lookup_value(gutils, offset, B) - if width == 1 - args = LLVM.Value[ - shadowin - offset - ] - 
LLVM.call!(B, fty, delF, args) - else - for idx in 1:width - args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - offset - ] - LLVM.call!(B, fty, delF, args) + # TODO get actual alignment + algn = 0 + + i8 = LLVM.IntType(8) + for idx in 1:width + anti = if width == 1 + shadowin + else + extract_value!(B, shadowin, idx-1) + end + if API.runtimeActivity() + emit_error(B, orig, "Enzyme: Not yet implemented runtime activity for reverse of jl_array_del_end") end + args = LLVM.Value[anti, offset] + + anti = shadowin + elSize = get_array_elsz(B, anti) + elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) + len = get_array_len(B, anti) + + LLVM.call!(B, fty, delF, args) + + length = LLVM.mul!(B, len, elSize) + + GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type" + toset = get_array_data(B, anti) + toset = gep!(B, i8, toset, LLVM.Value[length]) + LLVM.memset!(B, toset, LLVM.ConstantInt(i8, 0, false), elSize, algn) end - - # GPUCompiler.@safe_warn "Not applying memsetUnknown concrete type" tt=string(tt) - emit_error(B, orig, "Not applying memset on reverse of jl_array_del_end") - # memset(data + idx * elsz, 0, inc * elsz); end return nothing end diff --git a/src/utils.jl b/src/utils.jl index 28dd0b4d65..a3268c6c94 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,4 +1,11 @@ -unsafe_to_pointer(ptr) = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), ptr) +""" + unsafe_to_pointer + +!!! warning + Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. + Only use inside Enzyme.jl should be for Types. 
+""" +@inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) export unsafe_to_pointer @inline is_concrete_tuple(x::T2) where T2 = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) @@ -8,12 +15,29 @@ const Tracked = 10 const Derived = 11 export Tracked, Derived +const captured_constants = Base.IdSet{Any}() + +# This mimicks literal_pointer_val / literal_pointer_val_slot function unsafe_to_llvm(val) T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - fill_val = unsafe_to_pointer(val) - fill_val = LLVM.ConstantInt(convert(UInt, fill_val)) + # XXX: This prevents code from being runtime relocatable + # We likely should emit global variables and use something + # like `absolute_symbol_materialization` and write out cache-files + # that have relocation tables. + # TODO: What about things like `nothing` + if !Base.ismutable(val) + val = Core.Box(val) # FIXME many objects could be leaked here + @assert Base.ismutable(val) + push!(captured_constants, val) # Globally root + ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Base.pointer_from_objref(val))) + else + @assert Base.ismutable(val) + push!(captured_constants, val) # Globally root + ptr = Base.pointer_from_objref(val) + end + fill_val = LLVM.ConstantInt(convert(UInt, ptr)) fill_val = LLVM.const_inttoptr(fill_val, T_prjlvalue_UT) LLVM.const_addrspacecast(fill_val, T_prjlvalue) end diff --git a/test/runtests.jl b/test/runtests.jl index 0e9c455e99..d9d4c1ac37 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2380,6 +2380,26 @@ end @test dx ≈ [0 30 0] end + +function sparse_eval(x::Vector{Float64}) + A = sparsevec([1, 1, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0]) + B = sparsevec([1, 1, 2, 3], [2.0*x[2], 1.0-x[1], 2.0+x[3], -1.0]) + C = A + B + return A[1] +end + +@static if VERSION ≥ v"1.7-" +@testset "Type Unstable SparseArrays" begin + x = 
[3.1, 2.7, 8.2] + dx = [0.0, 0.0, 0.0] + + autodiff(Reverse, sparse_eval, Duplicated(x, dx)) + + @test x ≈ [3.1, 2.7, 8.2] + @test dx ≈ [-1.0, 43.74, 0] +end +end + @testset "Jacobian" begin function inout(v) [v[2], v[1]*v[1], v[1]*v[1]*v[1]] From 5ae36e5708f0452802c97029b566286ca5f2202b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 21 Apr 2024 06:31:40 +0200 Subject: [PATCH 012/495] [EnzymeTestUtils] Add rng keyword to testing functions and test with seed (#1398) --- lib/EnzymeTestUtils/Project.toml | 2 +- lib/EnzymeTestUtils/src/generate_tangent.jl | 23 ++++++++++++--------- lib/EnzymeTestUtils/src/test_forward.jl | 4 +++- lib/EnzymeTestUtils/src/test_reverse.jl | 8 ++++--- lib/EnzymeTestUtils/test/runtests.jl | 3 +++ 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 63cc2fe9ad..5069878de3 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.5" +version = "0.1.6" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/lib/EnzymeTestUtils/src/generate_tangent.jl b/lib/EnzymeTestUtils/src/generate_tangent.jl index 676e87c7a3..e5ae0dd7e3 100644 --- a/lib/EnzymeTestUtils/src/generate_tangent.jl +++ b/lib/EnzymeTestUtils/src/generate_tangent.jl @@ -26,21 +26,24 @@ zero_tangent(x) = map_fields_recursive(zero_tangent, x) zero_tangent(::T) where {T<:AbstractFloat} = zero(T) zero_tangent(x::T) where {T<:Array{<:Number}} = zero_tangent.(x) -function auto_activity(arg::Tuple) +auto_activity(arg) = auto_activity(Random.default_rng(), arg) +function auto_activity(rng, arg::Tuple) if length(arg) == 2 && arg[2] isa Type && arg[2] <: Annotation - return _build_activity(arg...) + return _build_activity(rng, arg...) 
end return Const(arg) end -auto_activity(activity::Annotation) = activity -auto_activity(activity) = Const(activity) +auto_activity(rng, activity::Annotation) = activity +auto_activity(rng, activity) = Const(activity) -_build_activity(primal, ::Type{<:Const}) = Const(primal) -_build_activity(primal, ::Type{<:Active}) = Active(primal) -_build_activity(primal, ::Type{<:Duplicated}) = Duplicated(primal, rand_tangent(primal)) -function _build_activity(primal, ::Type{<:BatchDuplicated}) - return BatchDuplicated(primal, ntuple(_ -> rand_tangent(primal), 2)) +_build_activity(rng, primal, ::Type{<:Const}) = Const(primal) +_build_activity(rng, primal, ::Type{<:Active}) = Active(primal) +function _build_activity(rng, primal, ::Type{<:Duplicated}) + return Duplicated(primal, rand_tangent(rng, primal)) end -function _build_activity(primal, T::Type{<:Annotation}) +function _build_activity(rng, primal, ::Type{<:BatchDuplicated}) + return BatchDuplicated(primal, ntuple(_ -> rand_tangent(rng, primal), 2)) +end +function _build_activity(rng, primal, T::Type{<:Annotation}) throw(ArgumentError("Unsupported activity type: $T")) end diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index 53ecea94d3..eaef915a4d 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -17,6 +17,7 @@ additional constraints: # Keywords +- `rng::AbstractRNG`: The random number generator to use for generating random tangents. - `fdm=FiniteDifferences.central_fdm(5, 1)`: The finite differences method to use. - `fkwargs`: Keyword arguments to pass to `f`. - `rtol`: Relative tolerance for `isapprox`. 
@@ -54,6 +55,7 @@ function test_forward( f, ret_activity, args...; + rng::Random.AbstractRNG=Random.default_rng(), fdm=FiniteDifferences.central_fdm(5, 1), fkwargs::NamedTuple=NamedTuple(), rtol::Real=1e-9, @@ -67,7 +69,7 @@ function test_forward( end @testset "$testset_name" begin # format arguments for autodiff and FiniteDifferences - activities = map(auto_activity, (f, args...)) + activities = map(Base.Fix1(auto_activity, rng), (f, args...)) primals = map(x -> x.val, activities) # call primal, avoid mutating original arguments fcopy = deepcopy(first(primals)) diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index c2671126fe..1f36a04a5a 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -39,6 +39,7 @@ additional constraints: # Keywords +- `rng::AbstractRNG`: The random number generator to use for generating random tangents. - `fdm=FiniteDifferences.central_fdm(5, 1)`: The finite differences method to use. - `fkwargs`: Keyword arguments to pass to `f`. - `rtol`: Relative tolerance for `isapprox`. @@ -75,6 +76,7 @@ function test_reverse( f, ret_activity, args...; + rng::Random.AbstractRNG=Random.default_rng(), fdm=FiniteDifferences.central_fdm(5, 1), fkwargs::NamedTuple=NamedTuple(), rtol::Real=1e-9, @@ -87,7 +89,7 @@ function test_reverse( end @testset "$testset_name" begin # format arguments for autodiff and FiniteDifferences - activities = map(auto_activity, (f, args...)) + activities = map(Base.Fix1(auto_activity, rng), (f, args...)) primals = map(x -> x.val, activities) # call primal, avoid mutating original arguments fcopy = deepcopy(first(primals)) @@ -95,12 +97,12 @@ function test_reverse( y = fcopy(args_copy...; deepcopy(fkwargs)...) # generate tangent for output if !_any_batch_duplicated(map(typeof, activities)...) - ȳ = ret_activity <: Const ? zero_tangent(y) : rand_tangent(y) + ȳ = ret_activity <: Const ? 
zero_tangent(y) : rand_tangent(rng, y) else batch_size = _batch_size(map(typeof, activities)...) ks = ntuple(Symbol ∘ string, batch_size) ȳ = ntuple(batch_size) do _ - ret_activity <: Const ? zero_tangent(y) : rand_tangent(y) + return ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y) end end # call finitedifferences, avoid mutating original arguments diff --git a/lib/EnzymeTestUtils/test/runtests.jl b/lib/EnzymeTestUtils/test/runtests.jl index 6eec393546..7785fe151a 100644 --- a/lib/EnzymeTestUtils/test/runtests.jl +++ b/lib/EnzymeTestUtils/test/runtests.jl @@ -1,6 +1,9 @@ using EnzymeTestUtils +using Random using Test +Random.seed!(0) + @testset "EnzymeTestUtils.jl" begin include("helpers.jl") include("test_approx.jl") From 8273a6e800ccce2a1cd67ec26f9940a552312dfb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 21 Apr 2024 09:39:19 +0200 Subject: [PATCH 013/495] [EnzymeTestUtils] Vectorize function for FiniteDifferencesCalls (#1327) * Add to_vec * Use to_vec for tangent generation * Fix incorrect call to test_reverse * Use to_vec in calls to FiniteDifferences * Increment patch number * Add more cases to test_approx * Handle cases where constructorof not implemented but needed * Correctly handle case where ret activity is batched and all else const * Replace NamedTuple method with Dict * Add function for structured array testing * Add structured array test * Add tests for to_vec * Add to_vec * Use to_vec for tangent generation * Fix incorrect call to test_reverse * Use to_vec in calls to FiniteDifferences * Increment patch number * Add more cases to test_approx * Handle cases where constructorof not implemented but needed * Correctly handle case where ret activity is batched and all else const * Replace NamedTuple method with Dict * Add function for structured array testing * Add structured array test * Add tests for to_vec * Apply suggestions from code review Co-authored-by: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com> * Add LinearAlgebra to test env * Run formatter on finitedifferences calls * Introduce AliasDict for checking for aliased arrays * Refactor to_vec to handle aliased arrays correctly * Test new to_vec behavior * Note difference between zero_tangent and make_zero * Restore deleted code * Don't treat immutable structs as equivalent * Remove obsolete limitation * Test cases where arrays alias * Document remaining limitation * Also test aliasing in when batching * Also test aliasing in forward-mode * Skip test that hits Julia GC bug pre v1.8 * Change mutating test to support returned arg * Clarify documentation of limitations * Skip structured array test for v1.7 * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Run formatter * Fix random seed in tests * Increment patch number --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/EnzymeTestUtils/Project.toml | 5 +- lib/EnzymeTestUtils/src/EnzymeTestUtils.jl | 1 + .../src/compatible_activities.jl | 2 +- .../src/finite_difference_calls.jl | 43 +++-- lib/EnzymeTestUtils/src/generate_tangent.jl | 48 ++++- lib/EnzymeTestUtils/src/test_approx.jl | 20 ++ lib/EnzymeTestUtils/src/test_forward.jl | 4 +- lib/EnzymeTestUtils/src/test_reverse.jl | 17 +- lib/EnzymeTestUtils/src/to_vec.jl | 155 ++++++++++++++++ lib/EnzymeTestUtils/test/helpers.jl | 25 +++ lib/EnzymeTestUtils/test/runtests.jl | 1 + lib/EnzymeTestUtils/test/test_approx.jl | 31 ++++ lib/EnzymeTestUtils/test/test_forward.jl | 48 ++++- lib/EnzymeTestUtils/test/test_reverse.jl | 64 ++++++- lib/EnzymeTestUtils/test/to_vec.jl | 175 ++++++++++++++++++ 15 files changed, 591 insertions(+), 48 deletions(-) create mode 100644 lib/EnzymeTestUtils/src/to_vec.jl create mode 100644 lib/EnzymeTestUtils/test/to_vec.jl diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml 
index 5069878de3..38b783facc 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.6" +version = "0.1.7" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -21,8 +21,9 @@ Quaternions = "0.7" julia = "1.6" [extras] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" [targets] -test = ["MetaTesting", "Quaternions"] +test = ["LinearAlgebra", "MetaTesting", "Quaternions"] diff --git a/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl b/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl index cc4266cdbd..56a050455b 100644 --- a/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl +++ b/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl @@ -10,6 +10,7 @@ using Test export test_forward, test_reverse, are_activities_compatible include("output_control.jl") +include("to_vec.jl") include("test_approx.jl") include("compatible_activities.jl") include("finite_difference_calls.jl") diff --git a/lib/EnzymeTestUtils/src/compatible_activities.jl b/lib/EnzymeTestUtils/src/compatible_activities.jl index dcb584e067..48ee1a24df 100644 --- a/lib/EnzymeTestUtils/src/compatible_activities.jl +++ b/lib/EnzymeTestUtils/src/compatible_activities.jl @@ -20,7 +20,7 @@ _batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N _batch_size(::Type{<:Annotation}) = nothing function _batch_size(activities...) 
sizes = filter(!isnothing, map(_batch_size, activities)) - isempty(sizes) && return nothing + isempty(sizes) && return 1 @assert all(==(sizes[1]), sizes) return sizes[1] end diff --git a/lib/EnzymeTestUtils/src/finite_difference_calls.jl b/lib/EnzymeTestUtils/src/finite_difference_calls.jl index 56dec44569..7433b9ccd9 100644 --- a/lib/EnzymeTestUtils/src/finite_difference_calls.jl +++ b/lib/EnzymeTestUtils/src/finite_difference_calls.jl @@ -22,17 +22,25 @@ function _fd_forward(fdm, f, rettype, y, activities) xs = map(x -> x.val, activities) ẋs = map(a -> a isa Const ? nothing : a.dval, activities) ignores = map(a -> a isa Const, activities) - f2 = _wrap_forward_function(f, xs, ignores) + f_sig_args = _wrap_forward_function(f, xs, ignores) ignores = collect(ignores) + _, from_vec_out = to_vec(y) + sig_arg_val_vec, from_vec_in = to_vec(xs[.!ignores]) + # vectorize inputs and outputs of function + f_vec = first ∘ to_vec ∘ Base.splat(f_sig_args) ∘ from_vec_in if rettype <: Union{Duplicated,DuplicatedNoNeed} all(ignores) && return zero_tangent(y) - sigargs = zip(xs[.!ignores], ẋs[.!ignores]) - return FiniteDifferences.jvp(fdm, f2, sigargs...) + sig_arg_dval_vec, _ = to_vec(ẋs[.!ignores]) + ret_deval_vec = FiniteDifferences.jvp(fdm, f_vec, + (sig_arg_val_vec, sig_arg_dval_vec)) + return from_vec_out(ret_deval_vec) elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed} all(ignores) && return (var"1"=zero_tangent(y),) - sig_arg_vals = xs[.!ignores] ret_dvals = map(ẋs[.!ignores]...) do sig_args_dvals... - FiniteDifferences.jvp(fdm, f2, zip(sig_arg_vals, sig_args_dvals)...) 
+ sig_args_dvals_vec, _ = to_vec(sig_args_dvals) + ret_dval_vec = FiniteDifferences.jvp(fdm, f_vec, + (sig_arg_val_vec, sig_args_dvals_vec)) + return from_vec_out(ret_dval_vec) end return NamedTuple{ntuple(Symbol, length(ret_dvals))}(ret_dvals) else @@ -58,7 +66,7 @@ Call `FiniteDifferences.j′vp` on `f` with the arguments `xs` determined by `ac function _fd_reverse(fdm, f, ȳ, activities, active_return) xs = map(x -> x.val, activities) ignores = map(a -> a isa Const, activities) - f2 = _wrap_reverse_function(active_return, f, xs, ignores) + f_sig_args = _wrap_reverse_function(active_return, f, xs, ignores) all(ignores) && return map(zero_tangent, xs) ignores = collect(ignores) is_batch = _any_batch_duplicated(map(typeof, activities)...) @@ -74,18 +82,21 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return) sigargs = xs[.!ignores] s̄igargs = x̄s[.!ignores] sigarginds = eachindex(x̄s)[.!ignores] + sigargs_vec, from_vec_in = to_vec(sigargs) + # vectorize inputs and outputs of function + f_vec = first ∘ to_vec ∘ Base.splat(f_sig_args) ∘ from_vec_in if !is_batch - fd = FiniteDifferences.j′vp(fdm, f2, (ȳ, s̄igargs...), sigargs...) + ȳ_extended = (ȳ, s̄igargs...) + ȳ_extended_vec, _ = to_vec(ȳ_extended) + fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec)) + fd = from_vec_in(fd_vec) else - fd = Tuple( - zip( - map(ȳ, s̄igargs...) do y_dval, sigargs_dvals... - FiniteDifferences.j′vp( - fdm, f2, (y_dval, sigargs_dvals...), sigargs... - ) - end..., - ), - ) + fd = Tuple(zip(map(ȳ, s̄igargs...) do ȳ_extended... 
+ ȳ_extended_vec, _ = to_vec(ȳ_extended) + fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, + sigargs_vec)) + return from_vec_in(fd_vec) + end...)) end @assert length(fd) == length(sigarginds) x̄s[sigarginds] = collect(fd) diff --git a/lib/EnzymeTestUtils/src/generate_tangent.jl b/lib/EnzymeTestUtils/src/generate_tangent.jl index e5ae0dd7e3..d774036e7e 100644 --- a/lib/EnzymeTestUtils/src/generate_tangent.jl +++ b/lib/EnzymeTestUtils/src/generate_tangent.jl @@ -4,9 +4,9 @@ function map_fields_recursive(f, x::T...) where {T} fields = map(ConstructionBase.getfields, x) all(isempty, fields) && return first(x) new_fields = map(fields...) do xi... - map_fields_recursive(f, xi...) + return map_fields_recursive(f, xi...) end - return ConstructionBase.constructorof(T)(new_fields...) + return _construct(T, new_fields...) end function map_fields_recursive(f, x::T...) where {T<:Union{Array,Tuple,NamedTuple}} map(x...) do xi... @@ -17,14 +17,20 @@ map_fields_recursive(f, x::T...) where {T<:AbstractFloat} = f(x...) map_fields_recursive(f, x::Array{<:Number}...) = f(x...) rand_tangent(x) = rand_tangent(Random.default_rng(), x) -rand_tangent(rng, x) = map_fields_recursive(Base.Fix1(rand_tangent, rng), x) -# make numbers prettier sometimes when errors are printed. -rand_tangent(rng, ::T) where {T<:AbstractFloat} = rand(rng, -9:T(0.01):9) -rand_tangent(rng, x::T) where {T<:Array{<:Number}} = rand_tangent.(rng, x) +function rand_tangent(rng, x) + v, from_vec = to_vec(x) + T = eltype(v) + # make numbers prettier sometimes when errors are printed. + v_new = rand(rng, -9:T(0.01):9, length(v)) + return from_vec(v_new) +end -zero_tangent(x) = map_fields_recursive(zero_tangent, x) -zero_tangent(::T) where {T<:AbstractFloat} = zero(T) -zero_tangent(x::T) where {T<:Array{<:Number}} = zero_tangent.(x) +# differs from Enzyme.make_zero primarily in that reshaped Arrays in the argument will share +# the same memory in the output. 
+function zero_tangent(x) + v, from_vec = to_vec(x) + return from_vec(zero(v)) +end auto_activity(arg) = auto_activity(Random.default_rng(), arg) function auto_activity(rng, arg::Tuple) @@ -47,3 +53,27 @@ end function _build_activity(rng, primal, T::Type{<:Annotation}) throw(ArgumentError("Unsupported activity type: $T")) end + +# below code is adapted from https://github.com/JuliaDiff/FiniteDifferences.jl/blob/99ad77f05bdf6c023b249025dbb8edc746d52b4f/src/to_vec.jl +# MIT Expat License +# Copyright (c) 2018 Invenia Technical Computing + +# get around the constructors and make the type directly +# Note this is moderately evil accessing julia's internals +if VERSION >= v"1.3" + @generated function _force_construct(T, args...) + return Expr(:splatnew, :T, :args) + end +else + @generated function _force_construct(T, args...) + return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...) + end +end + +function _construct(T, args...) + try + return ConstructionBase.constructorof(T)(args...) + catch MethodError + return _force_construct(T, args...) + end +end diff --git a/lib/EnzymeTestUtils/src/test_approx.jl b/lib/EnzymeTestUtils/src/test_approx.jl index c36657827e..305bef4021 100644 --- a/lib/EnzymeTestUtils/src/test_approx.jl +++ b/lib/EnzymeTestUtils/src/test_approx.jl @@ -21,6 +21,26 @@ function test_approx(x::AbstractArray, y::AbstractArray, msg; kwargs...) end return nothing end +function test_approx(x::Tuple, y::Tuple, msg; kwargs...) + @test_msg "$msg: lengths must match" length(x) == length(y) + for i in eachindex(x) + msg_new = "$msg: ::$(typeof(x))[$i]" + test_approx(x[i], y[i], msg_new; kwargs...) + end + return nothing +end +function test_approx(x::Dict, y::Dict, msg; kwargs...) + @test_msg "$msg: keys must match" issetequal(keys(x), keys(y)) + for k in keys(x) + msg_new = "$msg: ::$(typeof(x))[$k]" + test_approx(x[k], y[k], msg_new; kwargs...) + end + return nothing +end +function test_approx(x::Type, y::Type, msg; kwargs...) 
+ @test_msg "$msg: types must match" x === y + return nothing +end test_approx(x, y, msg; kwargs...) = _test_fields_approx(x, y, msg; kwargs...) function _test_fields_approx(x, y, msg; kwargs...) diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index eaef915a4d..e57a5c7e34 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -3,8 +3,8 @@ Test `Enzyme.autodiff` of `f` in `Forward`-mode against finite differences. -`f` has all constraints of the same argument passed to `Enzyme.autodiff`, with several -additional constraints: +`f` has all constraints of the same argument passed to `Enzyme.autodiff`, with additional +constraints: - If it mutates one of its arguments, it _must_ return that argument. # Arguments diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index 1f36a04a5a..f204b00a7b 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -8,7 +8,7 @@ for N in 1:30 function call_with_kwargs(fkwargs::NT, f::FT, $(argexprs...)) where {NT, FT} Base.@_inline_meta @static if VERSION ≤ v"1.8" - # callsite inline syntax unsupported in <= 1.8 + # callsite inline syntax unsupported in <= 1.8 f($(argexprs...); fkwargs...) else @inline f($(argexprs...); fkwargs...) @@ -23,11 +23,10 @@ end Test `Enzyme.autodiff_thunk` of `f` in `ReverseSplitWithPrimal`-mode against finite differences. -`f` has all constraints of the same argument passed to `Enzyme.autodiff_thunk`, with several +`f` has all constraints of the same argument passed to `Enzyme.autodiff_thunk`, with additional constraints: -- If it mutates one of its arguments, it must not also return that argument. -- If the return value is a struct, then all floating point numbers contained in the struct - or its fields must be in arrays. 
+- If an `Array{<:AbstractFloat}` appears in the input/output, then a reshaped version of it + may not also appear in the input/output. # Arguments @@ -96,13 +95,13 @@ function test_reverse( args_copy = deepcopy(Base.tail(primals)) y = fcopy(args_copy...; deepcopy(fkwargs)...) # generate tangent for output - if !_any_batch_duplicated(map(typeof, activities)...) + if !_any_batch_duplicated(ret_activity, map(typeof, activities)...) ȳ = ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y) else - batch_size = _batch_size(map(typeof, activities)...) + batch_size = _batch_size(ret_activity, map(typeof, activities)...) ks = ntuple(Symbol ∘ string, batch_size) ȳ = ntuple(batch_size) do _ - return ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y) + return ret_activity <: Const ? zero_tangent(y) : rand_tangent(y) end end # call finitedifferences, avoid mutating original arguments @@ -137,7 +136,7 @@ function test_reverse( else # if there's a shadow result, then we need to set it to our random adjoint if !(shadow_result === nothing) - if !_any_batch_duplicated(map(typeof, activities)...) + if !_any_batch_duplicated(ret_activity, map(typeof, activities)...) 
map_fields_recursive(copyto!, shadow_result, ȳ) else for (sr, dy) in zip(shadow_result, ȳ) diff --git a/lib/EnzymeTestUtils/src/to_vec.jl b/lib/EnzymeTestUtils/src/to_vec.jl new file mode 100644 index 0000000000..412c6efb1b --- /dev/null +++ b/lib/EnzymeTestUtils/src/to_vec.jl @@ -0,0 +1,155 @@ +# Like an IdDict, but also handles cases where 2 arrays share the same memory due to +# reshaping +struct AliasDict{K,V} <: AbstractDict{K,V} + id_dict::IdDict{K,V} + dataids_dict::IdDict{Tuple{UInt,Vararg{UInt}},V} +end +AliasDict() = AliasDict(IdDict(), IdDict{Tuple{UInt,Vararg{UInt}},Any}()) + +function Base.haskey(d::AliasDict, key) + haskey(d.id_dict, key) && return true + key isa Array && haskey(d.dataids_dict, Base.dataids(key)) && return true + return false +end + +Base.getindex(d::AliasDict, key) = d.id_dict[key] +function Base.getindex(d::AliasDict, key::Array) + haskey(d.id_dict, key) && return d.id_dict[key] + dataids = Base.dataids(key) + return d.dataids_dict[dataids] +end + +function Base.setindex!(d::AliasDict, val, key) + d.id_dict[key] = val + if key isa Array + dataids = Base.dataids(key) + d.dataids_dict[dataids] = val + end + return d +end + +# alternative to FiniteDifferences.to_vec to use Enzyme's semantics for arrays instead of +# ChainRules': Enzyme treats tangents of AbstractArrays the same as tangents of any other +# struct (i.e. with a container of the same type as the original), while ChainRules +# represents the tangent with an array of some type that is tangent to the subspace defined +# by the original array type. 
+# We take special care that floats that occupy the same memory in the argument only appear +# once in the vector, and that the reconstructed object shares the same memory pattern + +function to_vec(x) + x_vec, from_vec_inner = to_vec(x, AliasDict()) + from_vec(x_vec::Vector{<:AbstractFloat}) = from_vec_inner(x_vec, AliasDict()) + return x_vec, from_vec +end + +# base case: we've unwrapped to a number, so we break the recursion +function to_vec(x::AbstractFloat, seen_vecs::AliasDict) + AbstractFloat_from_vec(v::Vector{<:AbstractFloat}, _) = oftype(x, only(v)) + return [x], AbstractFloat_from_vec +end + +# basic containers: loop over defined elements, recursively converting them to vectors +function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:Array} + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const_nongen(RT, nothing) + if has_seen || is_const + x_vec = Float32[] + else + x_vecs = Vector{<:AbstractFloat}[] + from_vecs = [] + subvec_inds = UnitRange{Int}[] + l = 0 + for i in eachindex(x) + isassigned(x, i) || continue + xi_vec, xi_from_vec = to_vec(x[i], seen_vecs) + push!(x_vecs, xi_vec) + push!(from_vecs, xi_from_vec) + push!(subvec_inds, (l + 1):(l + length(xi_vec))) + l += length(xi_vec) + end + x_vec = reduce(vcat, x_vecs; init=Float32[]) + seen_vecs[x] = x_vec + end + function Array_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return reshape(seen_xs[x], size(x)) + is_const && return x + x_new = typeof(x)(undef, size(x)) + k = 1 + for i in eachindex(x) + isassigned(x, i) || continue + xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs) + x_new[i] = xi + k += 1 + end + seen_xs[x] = x_new + return x_new + end + return x_vec, Array_from_vec +end +function to_vec(x::Tuple, seen_vecs::AliasDict) + x_vec, from_vec = to_vec(collect(x), seen_vecs) + function 
Tuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + return typeof(x)(Tuple(from_vec(x_vec_new, seen_xs))) + end + return x_vec, Tuple_from_vec +end +function to_vec(x::NamedTuple, seen_vecs::AliasDict) + x_vec, from_vec = to_vec(values(x), seen_vecs) + function NamedTuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + return NamedTuple{keys(x)}(from_vec(x_vec_new, seen_xs)) + end + return x_vec, NamedTuple_from_vec +end + +# fallback: for any other struct, loop over fields, recursively converting them to vectors +function to_vec(x::RT, seen_vecs::AliasDict) where {RT} + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const_nongen(RT, nothing) + if has_seen || is_const + x_vec = Float32[] + else + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + flds = Vector{Any}(undef, nf) + for i in 1:nf + if isdefined(x, i) + flds[i] = xi = getfield(x, i) + elseif !ismutable(x) + nf = i - 1 # rest of tail must be undefined values + break + end + end + x_vec, fields_from_vec = to_vec(flds, seen_vecs) + if ismutable(x) + seen_vecs[x] = x_vec + end + end + function Struct_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Objects must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return seen_xs[x] + (is_const || nf == 0) && return x + flds_new = fields_from_vec(x_vec_new, seen_xs) + if ismutable(x) + x_new = ccall(:jl_new_struct_uninit, Any, (Any,), RT) + for i in 1:nf + if isdefined(x, i) + xi = flds_new[i] + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x_new, i - 1, xi) + end + end + else + x_new = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds_new, nf) + end + if ismutable(x) + seen_xs[x] = x_new + end + return x_new + end + return x_vec, Struct_from_vec +end diff --git a/lib/EnzymeTestUtils/test/helpers.jl 
b/lib/EnzymeTestUtils/test/helpers.jl index c3e3ece134..6754b0a935 100644 --- a/lib/EnzymeTestUtils/test/helpers.jl +++ b/lib/EnzymeTestUtils/test/helpers.jl @@ -1,8 +1,22 @@ +using LinearAlgebra + struct TestStruct{X,A} x::X a::A end +struct TestStruct2 + x::Any + a::Any + TestStruct2(x) = new(x) +end + +mutable struct MutableTestStruct + x::Any + a::Any + MutableTestStruct() = new() +end + struct MutatedCallable{T} x::T end @@ -14,3 +28,14 @@ end f_array(x) = sum(abs2, x) f_multiarg(x::AbstractArray, a) = abs2.(a .* x) + +function f_structured_array(x::Hermitian) + y = x * 3 + # mutate the unused triangle, which ensures that our Jacobian differs from FiniteDifferences + if y.uplo == 'U' + LowerTriangular(y.data) .*= 2 + else + UpperTriangular(y.data) .*= 2 + end + return y +end diff --git a/lib/EnzymeTestUtils/test/runtests.jl b/lib/EnzymeTestUtils/test/runtests.jl index 7785fe151a..8883ee78ef 100644 --- a/lib/EnzymeTestUtils/test/runtests.jl +++ b/lib/EnzymeTestUtils/test/runtests.jl @@ -8,6 +8,7 @@ Random.seed!(0) include("helpers.jl") include("test_approx.jl") include("compatible_activities.jl") + include("to_vec.jl") include("generate_tangent.jl") include("test_forward.jl") include("test_reverse.jl") diff --git a/lib/EnzymeTestUtils/test/test_approx.jl b/lib/EnzymeTestUtils/test/test_approx.jl index 57f1145576..99b8070bc1 100644 --- a/lib/EnzymeTestUtils/test/test_approx.jl +++ b/lib/EnzymeTestUtils/test/test_approx.jl @@ -25,6 +25,37 @@ end @test fails(() -> test_approx([0, 1], [0, 1 + 1e-9]; rtol=1e-9)) @test errors(() -> test_approx([1, 2], [1, 2, 3])) end + @testset "tuples" begin + test_approx((1, 2), (1, 2)) + test_approx((1, 2), (1, 2 + 1e-9); atol=1.1e-9) + @test fails(() -> test_approx((1, 2), (1, 2 + 1e-9); atol=1e-9)) + test_approx((0, 1), (0, 1 + 1e-9); rtol=1.1e-9) + @test fails(() -> test_approx((0, 1), (0, 1 + 1e-9); rtol=1e-9)) + @test fails(() -> test_approx((1, 2), (1, 2, 3))) + end + @testset "type" begin + test_approx(Bool, Bool) + 
test_approx(String, String) + @test fails(() -> test_approx(Bool, String)) + end + @testset "dict" begin + x1 = Dict(:x => randn(3), :y => randn(2)) + x2 = Dict(:x => copy(x1[:x]), :y => copy(x1[:y])) + test_approx(x1, x2) + for i in eachindex(x2[:x]), err in (1e-2, 1e-9) + y = copy(x1[:x]) + y[i] += rand((-1, 1)) * err + x2[:x] = y + test_approx(x1, x2; atol=err * 1.1) + @test fails() do + return test_approx(x1, x2; atol=err * 0.9) + end + end + x2[:x] = vcat(x1[:x], 1.0) + @test errors() do + return test_approx(x1, x2; atol=err * 0.9) + end + end @testset "non-numeric types" begin test_approx(:x, :x) @test fails(() -> test_approx(:x, :y)) diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index a2ab010042..7f870af7bf 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -1,5 +1,6 @@ using Enzyme using EnzymeTestUtils +using LinearAlgebra using MetaTesting using Test @@ -133,6 +134,49 @@ end end end + VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + T in (Float32, Float64, ComplexF32, ComplexF64) + + # if some are batch, none must be duplicated + are_activities_compatible(Tret, Tx) || continue + + x = Hermitian(randn(T, 5, 5)) + + atol = rtol = sqrt(eps(real(T))) + test_forward(f_structured_array, Tret, (x, Tx); atol, rtol) + end + end + + @testset "equivalent arrays in output" begin + function f(x) + z = x * 2 + return (z, z) + end + x = randn(2, 3) + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + test_forward(f, Tret, (x, Tx)) + end + end + + @testset "arrays sharing memory in output" begin + function f(x) + z = x * 2 + return (z, z) + end + x = randn(2, 3) + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in 
(Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + test_forward(f, Tret, (x, Tx)) + end + end + @testset "mutating function" begin Enzyme.API.runtimeActivity!(true) sz = (2, 3) @@ -163,10 +207,10 @@ end x = randn(3) a = randn() - test_reverse(f_kwargs_fwd!, Const, (x, Tx); fkwargs=(; a)) + test_forward(f_kwargs_fwd!, Const, (x, Tx); fkwargs=(; a)) fkwargs = (; a, incorrect_primal=true) @test fails() do - test_forward(f_kwargs_fwd!, Const, (x, Tx); fkwargs) + return test_forward(f_kwargs_fwd!, Const, (x, Tx); fkwargs) end end end diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl index f73f3eaed3..b394fa171d 100644 --- a/lib/EnzymeTestUtils/test/test_reverse.jl +++ b/lib/EnzymeTestUtils/test/test_reverse.jl @@ -1,11 +1,12 @@ using Enzyme using EnzymeTestUtils +using LinearAlgebra using MetaTesting using Test function f_mut_rev!(y, x, a) map!(xi -> xi * a, y, x) - return nothing + return y end f_kwargs_rev(x; a=3.0, kwargs...) 
= a .* x .^ 2 @@ -90,23 +91,72 @@ end end end + VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + T in (Float32, Float64, ComplexF32, ComplexF64) + + # if some are batch, none must be duplicated + are_activities_compatible(Tret, Tx) || continue + + x = Hermitian(randn(T, 5, 5)) + + atol = rtol = sqrt(eps(real(T))) + test_reverse(f_structured_array, Tret, (x, Tx); atol, rtol) + end + end + + @testset "equivalent arrays in output" begin + function f(x) + z = x * 2 + return (z, z) + end + x = randn(2, 3) + + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + test_reverse(f, Tret, (x, Tx)) + end + end + + @testset "arrays sharing memory in output" begin + function f(x) + z = x * 2 + return (z, vec(z)) + end + x = randn(2, 3) + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + if Tx <: Const + test_reverse(f, Tret, (x, Tx)) + else + @test_broken !fails() do + return test_reverse(f, Tret, (x, Tx)) + end + end + end + end + @testset "mutating function" begin sz = (2, 3) @testset for Ty in (Const, Duplicated, BatchDuplicated), - Tx in (Const, Duplicated, BatchDuplicated), - Ta in (Const, Active), - Tret in (Const,), # return value is nothing - T in (Float32, Float64, ComplexF32, ComplexF64) + Tx in (Const, Duplicated, BatchDuplicated), + Ta in (Const, Active), + T in (Float32, Float64, ComplexF32, ComplexF64) # if some are batch, none must be duplicated - are_activities_compatible(Tret, Ty, Tx, Ta) || continue + are_activities_compatible(Ty, Tx, Ta) || continue x = randn(T, sz) y = zeros(T, sz) a = randn(T) atol = rtol = sqrt(eps(real(T))) - test_reverse(f_mut_rev!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol) + 
test_reverse(f_mut_rev!, Ty, (y, Ty), (x, Tx), (a, Ta); atol, rtol) end end diff --git a/lib/EnzymeTestUtils/test/to_vec.jl b/lib/EnzymeTestUtils/test/to_vec.jl new file mode 100644 index 0000000000..3f7609d47a --- /dev/null +++ b/lib/EnzymeTestUtils/test/to_vec.jl @@ -0,0 +1,175 @@ +using EnzymeTestUtils +using EnzymeTestUtils: to_vec +using Test + +function test_to_vec(x) + x_vec, from_vec = to_vec(x) + @test x_vec isa Vector{<:AbstractFloat} + x2 = from_vec(x_vec) + @test typeof(x2) === typeof(x) + return EnzymeTestUtils.test_approx(x2, x) +end + +@testset "to_vec" begin + @testset "BLAS floats" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + x = randn(T) + test_to_vec(x) + if T <: Real + @test to_vec(x)[1] == [x] + else + @test to_vec(x)[1] == [reim(x)...] + end + end + end + + @testset "non-vectorizable cases" begin + @testset for x in [Bool, (), true, 1, [2], (3, "string")] + test_to_vec(x) + @test isempty(to_vec(x)[1]) + end + end + + @testset "array of floats" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64), + sz in (2, (2, 3), (2, 3, 4)) + + test_to_vec(randn(T, sz)) + end + end + + @testset "struct" begin + v = randn(2, 3) + x = TestStruct(1, TestStruct("foo", v)) + test_to_vec(x) + @test to_vec(x)[1] == vec(v) + + x = (TestStruct(1.0, 2.0), TestStruct(1.0, 2.0)) + v, from_vec = to_vec(x) + @test v == [1.0, 2.0, 1.0, 2.0] + @test from_vec(v) === x + end + + @testset "incompletely initialized struct" begin + x = randn(2, 3) + y = TestStruct2(x) + v, from_vec = to_vec(y) + @test v == vec(x) + v2 = randn(size(v)) + y2 = from_vec(v2) + @test y2.x == reshape(v2, size(x)) + @test !isdefined(y2, :a) + end + + @testset "mutable struct" begin + @testset for k in (:a, :x) + x = randn(2, 3) + y = MutableTestStruct() + setfield!(y, k, x) + @test isdefined(y, k) + @test getfield(y, k) == x + v, from_vec = to_vec(y) + @test v == vec(x) + v2 = randn(size(v)) + y2 = from_vec(v2) + @test getfield(y2, k) == reshape(v2, size(x)) + 
@test !isdefined(y2, k === :a ? :x : :a) + end + + y = MutableTestStruct() + y.x = randn() + t = (y, y) + v, from_vec = to_vec(t) + @test v == [y.x] + t2 = from_vec(v) + @test t2[1] === t2[2] + + t = (y, deepcopy(y)) + v, from_vec = to_vec(t) + @test v == [y.x, y.x] + t2 = from_vec(v) + @test t2[1].x == t2[2].x + @test t2[1] !== t2[2] + end + + @testset "nested array" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64), + sz in (2, (2, 3), (2, 3, 4)) + + test_to_vec([randn(T, sz) for _ in 1:10]) + end + end + + @testset "partially defined array" begin + @testset for i in 1:2 + x = Vector{Vector{Float64}}(undef, 2) + x[i] = randn(5) + v, from_vec = to_vec(x) + @test v == x[i] + v2 = randn(size(v)) + x2 = from_vec(v2) + @test x2[i] == v2 + @test !isassigned(x2, 3 - i) + end + end + + @testset "tuple" begin + v = randn(3) + x = ("foo", 1, false, String, TestStruct(3.0, v)) + test_to_vec(x) + @test to_vec(x)[1] == vcat(3.0, v) + end + + @testset "namedtuple" begin + x = (x="bar", y=randn(3), z=randn(), w=TestStruct(4.0, randn(2))) + test_to_vec(x) + @test to_vec(x)[1] == vcat(x.y, x.z, x.w.x, x.w.a) + end + + @testset "dict" begin + x = Dict(:a => randn(2), :b => randn(3)) + test_to_vec(x) + end + + @testset "views of arrays" begin + x = randn(2, 3) + test_to_vec(reshape(x, 3, 2)) + test_to_vec(view(x, :, 1)) + test_to_vec(PermutedDimsArray(x, (2, 1))) + end + + @testset "subarrays" begin + x = randn(2, 3) + # note: bottom right 1x2 submatrix omitted from y but will be present in v + y = @views (x[:, 1], x[1, :]) + test_to_vec(y) + v, from_vec = to_vec(y) + @test v == vec(x) + v2 = randn(size(v)) + y2 = from_vec(v2) + @test y2[1] == reshape(v2, size(x))[:, 1] + @test y2[2] == reshape(v2, size(x))[1, :] + @test Base.dataids(y2[1]) == Base.dataids(y2[2]) + end + + @testset "reshaped arrays share memory" begin + struct MyContainer1 + a::Any + b::Any + end + mutable struct MyContainer2 + a::Any + b::Any + end + @testset for T in (MyContainer1, 
MyContainer2) + x = randn(2, 3) + x2 = vec(x) + y = T(x, x2) + test_to_vec(y) + v, from_vec = to_vec(y) + @test v == x2 + y2 = from_vec(v) + @test Base.dataids(y2.a) == Base.dataids(y2.b) + end + end +end From 34a5ce3fe694b54f0edf2c9f8411e66086e9be27 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 22 Apr 2024 14:22:54 -0400 Subject: [PATCH 014/495] stabilize default_adjoint (#1403) --- src/compiler.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1cc7f14de8..1ea57ab697 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5251,9 +5251,15 @@ end end end -@inline default_adjoint(::Type{T}) where T = error("Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)") -@inline default_adjoint(::Type{T}) where T<:AbstractFloat = one(T) -@inline default_adjoint(::Type{Complex{T}}) where T = error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff") +@inline function default_adjoint(T) + if T <: AbstractFloat + return one(T) + elseif T <: Complex + error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). 
For the second case, please use regular non-deferred autodiff") + else + error("Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)") + end +end function add_one_in_place(x) ty = typeof(x) From 1e27530c10989926c45377e1efd47f047415603e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 22 Apr 2024 14:29:33 -0400 Subject: [PATCH 015/495] Null addr cast (#1404) --- src/compiler/optimize.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 110a1636cb..5be2d712a8 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -376,11 +376,15 @@ function nodecayed_phis!(mod::LLVM.Module) if opcode(v) == LLVM.API.LLVMAddrSpaceCast v2 = operands(v)[1] if addrspace(value_type(v2)) == 0 - if addr == 11 && isa(v, LLVM.ConstantExpr) - v2 = const_addrspacecast(operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) + if addr == 11 + v2 = const_addrspacecast(v2, LLVM.PointerType(eltype(value_type(v)), 10)) return v2, offset, hasload end end + if LLVM.isnull(v2) + v2 = const_addrspacecast(v2, LLVM.PointerType(eltype(value_type(v)), 10)) + return v2, offset, hasload + end end end From aef17d1b02ca02b6dd2af9923a431b15bdc0240e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 3 May 2024 18:24:38 -0400 Subject: [PATCH 016/495] Fix autodiff_deferred_thunk abi (#1412) --- src/Enzyme.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 5168e116fb..a45aad9596 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -768,7 +768,7 @@ function f(A, v) end TapeType = tape_type(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(A)}, Active{typeof(v)}) -forward, reverse = 
autodiff_deferred_thunk(ReverseSplitWithPrimal, TapeType, Const{typeof(f)}, Active, Active{Float64}, Duplicated{typeof(A)}, Active{typeof(v)}) +forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, TapeType, Const{typeof(f)}, Active{Float64}, Duplicated{typeof(A)}, Active{typeof(v)}) tape, result, shadow_result = forward(Const(f), Duplicated(A, ∂A), Active(v)) _, ∂v = reverse(Const(f), Duplicated(A, ∂A), Active(v), 1.0, tape)[1] @@ -780,7 +780,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, A2, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, A2, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(args...) 
From 4c2a26a343ccfc89879e5271ee6fc58862e4cf9f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 3 May 2024 18:25:48 -0400 Subject: [PATCH 017/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1032750266..7340591dc4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.0" +version = "0.12.1" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 61b4a94f21088f6da0ae8f60d00427bc8bfb83e6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 5 May 2024 22:11:02 -0400 Subject: [PATCH 018/495] Fix test (#1414) --- test/DiffTests.jl | 12 +++++++++++- test/runtests.jl | 1 - 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/test/DiffTests.jl b/test/DiffTests.jl index fead15d23f..98851f1559 100644 --- a/test/DiffTests.jl +++ b/test/DiffTests.jl @@ -29,8 +29,13 @@ num2num_3(x) = 10.31^(x + x) - x num2num_4(x) = 1.0 num2num_5(x) = 1. / (1. 
+ exp(-x)) +@static if sizeof(Int) == sizeof(Int64) || VERSION ≥ v"1.7-" const NUMBER_TO_NUMBER_FUNCS = (num2num_1, num2num_2, num2num_3, num2num_4, num2num_5, identity) +else +const NUMBER_TO_NUMBER_FUNCS = (num2num_1, num2num_2, num2num_3, + num2num_4, identity) +end ####################### # f(x::Number)::Array # @@ -122,12 +127,17 @@ const VECTOR_TO_NUMBER_FUNCS = (vec2num_1, vec2num_2, vec2num_3, vec2num_4, vec #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, rosenbrock_3, #=rosenbrock_4,=# ackley, self_weighted_logit, first) -else +elseif sizeof(Int) == sizeof(Int64) || VERSION ≥ v"1.7-" # vec2num_6 fails due to #708 const VECTOR_TO_NUMBER_FUNCS = (vec2num_1, vec2num_2, vec2num_3, vec2num_4, vec2num_5, #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, rosenbrock_3, rosenbrock_4, ackley, self_weighted_logit, first) +else +const VECTOR_TO_NUMBER_FUNCS = (#=vec2num_1,=# vec2num_2, vec2num_3, vec2num_4, vec2num_5, + #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, + rosenbrock_3, rosenbrock_4, #=ackley,=# self_weighted_logit, + first) end ######################## # f(x::Matrix)::Number # diff --git a/test/runtests.jl b/test/runtests.jl index d9d4c1ac37..486b00d402 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -325,7 +325,6 @@ end ReverseSplitWithPrimal, TapeType, Const{typeof(dot)}, - Active, Active{Float64}, Duplicated{typeof(thunk_A)} ) From 151ea309b81a8382efe10693e03f890dd97a9d69 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 5 May 2024 22:13:13 -0400 Subject: [PATCH 019/495] More debug info on error (#1396) * More debug info on error * Disable bad 1.6 x86 tests --- test/runtests.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 486b00d402..a8506016ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,9 @@ using InlineStrings using Enzyme_jll @info "Testing against" Enzyme_jll.libEnzyme +function isapproxfn(fn, args...; kwargs...) 
+ isapprox(args...; kwargs...) +end # Test against FiniteDifferences function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) ∂x, = autodiff(ReverseHolomorphic, f, Active, Active(x))[1] @@ -37,7 +40,7 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs.. fdm(f, x) end - @test isapprox(∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) + @test isapproxfn((Enzyme.Reverse, f), ∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) if typeof(x) <: Integer x = Float64(x) @@ -51,7 +54,7 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs.. ∂x, = autodiff(Forward, f, Duplicated(x, one(typeof(x)))) end - @test isapprox(∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) + @test isapproxfn((Enzyme.Reverse, f), ∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) end @@ -267,11 +270,17 @@ make3() = (1.0, 2.0, 3.0) test_scalar(cbrt, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) test_scalar(Base.sinh, 1.0) test_scalar(Base.cosh, 1.0) + if sizeof(Int) == sizeof(Int64) || VERSION ≥ v"1.7-" test_scalar(Base.sinc, 2.2) + end test_scalar(Base.FastMath.sinh_fast, 1.0) test_scalar(Base.FastMath.cosh_fast, 1.0) + if sizeof(Int) == sizeof(Int64) || VERSION ≥ v"1.7-" test_scalar(Base.FastMath.exp_fast, 1.0) + end + if sizeof(Int) == sizeof(Int64) || VERSION ≥ v"1.7-" test_scalar(Base.exp10, 1.0) + end test_scalar(Base.exp2, 1.0) test_scalar(Base.expm1, 1.0) test_scalar(x->rem(x, 1), 0.7) From d62e4f31a194164ff461d1e4c87212aa1220004a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 6 May 2024 22:35:38 -0400 Subject: [PATCH 020/495] Special case make_zero(array) for perf (#1415) * Special case make_zero(array) for perf * Additional stabilizations --- src/Enzyme.jl | 175 ++++++++++++++++++++++++++---------------------- src/compiler.jl | 58 +++++++++++++--- test/abi.jl | 34 ++++++++++ 3 files changed, 178 insertions(+), 89 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index a45aad9596..0fa3f6bb90 100644 --- a/src/Enzyme.jl +++ 
b/src/Enzyme.jl @@ -67,6 +67,13 @@ end end) end +@inline function vaTypeof(args::Vararg{Any, N}) where N + return Tuple{(ntuple(Val(N)) do i + Base.@_inline_meta + Core.Typeof(args[i]) + end)...} +end + @inline function same_or_one_helper(current, next) if current == -1 return next @@ -92,17 +99,60 @@ end same_or_one_rec(same_or_one_helper(current, N), args...) @inline same_or_one_rec(current, arg, args...) = same_or_one_rec(current, args...) -@inline function same_or_one(args...) - res = same_or_one_rec(-1, args...) - if res == -1 - return 1 +@inline function same_or_one(defaultVal, args...) + local_soo_res = same_or_one_rec(-1, args...) + if local_soo_res == -1 + defaultVal else - return res + local_soo_res + end +end + + +@inline function refn_seed(x::T) where T + if T <: Complex + return conj(x) / 2 + else + return x + end +end + +@inline function imfn_seed(x::T) where T + if T <: Complex + return im * conj(x) / 2 + else + return T(0) + end +end + +@inline function seed_complex_args(seen, seen2, args::Vararg{Annotation, Nargs}) where {Nargs} + return ntuple(Val(Nargs)) do i + Base.@_inline_meta + arg = args[i] + if arg isa Const || arg isa Active + arg + elseif arg isa Duplicated || arg isa DuplicatedNoNeed + RT = eltype(Core.Typeof(arg)) + BatchDuplicated(arg.val, (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval))) + else + throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + end + end +end + +@inline function fuse_complex_results(results, args::Vararg{Annotation, Nargs}) where {Nargs} + ntuple(Val(Nargs)) do i + Base.@_inline_meta + if args[i] isa Active + Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn_seed), results[1][i][3], imfn_seed) + else + results[1][i] + end end end """ - autodiff(::ReverseMode, f, Activity, args::Vararg{Annotation, Nargs}) + autodiff(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) Auto-differentiate 
function `f` at arguments `args` using reverse mode. @@ -161,9 +211,9 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. """ -@inline function autodiff(::ReverseMode{ReturnPrimal, RABI,Holomorphic}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI, Nargs,Holomorphic} - tt′ = Tuple{map(Core.Typeof, args)...} - width = same_or_one(args...) +@inline function autodiff(::ReverseMode{ReturnPrimal, RABI,Holomorphic}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs} + tt′ = vaTypeof(args...) + width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -207,59 +257,25 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) end - args = ntuple(Val(Nargs)) do i - Base.@_inline_meta - arg = args[i] - if arg isa Const || arg isa Active - arg - elseif arg isa Duplicated || arg isa DuplicatedNoNeed - RT = eltype(Core.Typeof(arg)) - BatchDuplicated(arg.val, (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval))) - else - throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) - end - end - width = same_or_one_rec(3, args...) - tt′ = Tuple{map(Core.Typeof, args)...} + width = same_or_one(3, args...) + args = seed_complex_args(seen, seen2, args...) + tt′ = vaTypeof(args...) 
thunk = Enzyme.Compiler.thunk(Val(world), typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) results = thunk(f, args..., (rt(0), rt(1), rt(im))) - @inline function refn(x::T) where T - if T <: Complex - return conj(x) / 2 - else - return x - end - end - - @inline function imfn(x::T) where T - if T <: Complex - return im * conj(x) / 2 - else - return T(0) - end - end - # compute the correct complex derivative in reverse mode by propagating the conjugate return values # then subtracting twice the imaginary component to get the correct result for (k, v) in seen - Compiler.recursive_accumulate(k, v, refn) + Compiler.recursive_accumulate(k, v, refn_seed) end for (k, v) in seen2 - Compiler.recursive_accumulate(k, v, imfn) + Compiler.recursive_accumulate(k, v, imfn_seed) end - fused = ntuple(Val(Nargs)) do i - Base.@_inline_meta - if args[i] isa Active - Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn), results[1][i][3], imfn) - else - results[1][i] - end - end + fused = fuse_complex_results(results, args...) return (fused, results[2:end]...) end @@ -288,12 +304,11 @@ end end """ - autodiff(mode::Mode, f, args::Vararg{Annotation, Nargs}) + autodiff(mode::Mode, f, args...) Like [`autodiff`](@ref) but will try to guess the activity of the return value. """ @inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} - tt′ = Tuple{map(Core.Typeof, args)...} tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} rt = Core.Compiler.return_type(f.val, tt) A = guess_activity(rt, mode) @@ -301,7 +316,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. 
end """ - autodiff(::ForwardMode, f, Activity, args::Vararg{Annotation, Nargs}) + autodiff(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) Auto-differentiate function `f` at arguments `args` using forward mode. @@ -349,8 +364,8 @@ f(x) = x*x if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = Tuple{map(Core.Typeof, args)...} - width = same_or_one(args...) + tt′ = vaTypeof(args...) + width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -385,14 +400,14 @@ f(x) = x*x end """ - autodiff_deferred(::ReverseMode, f, Activity, args::Vararg{Annotation, Nargs}) + autodiff_deferred(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred(::ReverseMode{ReturnPrimal}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs} - tt′ = Tuple{map(Core.Typeof, args)...} - width = same_or_one(args...) + tt′ = vaTypeof(args...) + width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -426,17 +441,17 @@ code, as well as high-order differentiation. end """ - autodiff_deferred(::ForwardMode, f, Activity, args::Vararg{Annotation, Nargs}) + autodiff_deferred(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. 
""" -@inline function autodiff_deferred(::ForwardMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}) where {FA<:Annotation, A<:Annotation, Nargs} +@inline function autodiff_deferred(::ForwardMode, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = Tuple{map(Core.Typeof, args)...} - width = same_or_one(args...) + tt′ = vaTypeof(args...) + width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -484,7 +499,7 @@ code, as well as high-order differentiation. end """ - autodiff_deferred(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) + autodiff_deferred(mode::Mode, f, ::Type{A}, args) Like [`autodiff_deferred`](@ref) but will try to extend f to an annotation, if needed. """ @@ -496,7 +511,7 @@ end end """ - autodiff_deferred(mode, f, args::Vararg{Annotation, Nargs}) + autodiff_deferred(mode, f, args...) Like [`autodiff_deferred`](@ref) but will try to guess the activity of the return value. """ @@ -513,7 +528,7 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur end """ - autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) + autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation, Nargs}) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -557,7 +572,7 @@ result, ∂v, ∂A """ @inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs} width = if Width == 0 - w = same_or_one(args...) + w = same_or_one(1, args...) 
if w == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -627,7 +642,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated ``` """ @inline function autodiff_thunk(::ForwardMode{RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs} - width = same_or_one(A, args...) + width = same_or_one(1, A, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -646,7 +661,7 @@ end @inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} width = if Width == 0 - w = same_or_one(args...) + w = same_or_one(1, args...) if w == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -679,10 +694,10 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType @inline function tape_type( parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, - ::Type{FA}, ::Type{A}, args... -) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} + ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs} +) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} width = if Width == 0 - w = same_or_one(args...) + w = same_or_one(1, args...) 
if w == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -780,10 +795,10 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, A2, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} @assert RABI == FFIABI width = if Width == 0 - w = same_or_one(args...) + w = same_or_one(1, args...) if w == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -928,14 +943,14 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) (a = 3.0, b = [2.0], c = "str") ``` """ -@inline function gradient(::ReverseMode, f::F, x::X) where {F, X} +@inline function gradient(rm::ReverseMode, f::F, x::X) where {F, X} if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState dx = Ref(make_zero(x)) - autodiff(Reverse, f∘only, Active, Duplicated(Ref(x), dx)) + autodiff(rm, f∘only, Active, Duplicated(Ref(x), dx)) return only(dx) else dx = make_zero(x) - autodiff(Reverse, f, Active, Duplicated(x, dx)) + autodiff(rm, f, Active, Duplicated(x, dx)) return dx end end @@ -970,7 +985,7 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) end """ - gradient(::ForwardMode, f, x::Array; shadow=onehot(x)) + gradient(::ForwardMode, f, x; shadow=onehot(x)) Compute the gradient of an array-input function `f` using forward mode. 
The optional keyword argument `shadow` is a vector of one-hot vectors of type `x` @@ -990,7 +1005,7 @@ grad = gradient(Forward, f, [2.0, 3.0]) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f, x::Array; shadow=onehot(x)) +@inline function gradient(::ForwardMode, f, x; shadow=onehot(x)) if length(x) == 0 return () end @@ -1011,7 +1026,7 @@ end @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) """ - gradient(::ForwardMode, f, x::Array, ::Val{chunk}; shadow=onehot(x)) + gradient(::ForwardMode, f, x::Union{Array,NTuple}, ::Val{chunk}; shadow=onehot(x)) Compute the gradient of an array-input function `f` using vector forward mode. Like [`gradient`](@ref), except it uses a chunk size of `chunk` to compute @@ -1029,7 +1044,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X<:Array, chunk} +@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -1039,7 +1054,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) tupleconcat(tmp...) 
end -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X<:Array} +@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} ntuple(length(shadow)) do i autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end diff --git a/src/compiler.jl b/src/compiler.jl index 1ea57ab697..dcfb9b4efc 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -362,6 +362,13 @@ end end end +@inline function staticInTup(::Val{T}, tup::NTuple{N, Val}) where {T, N} + any(ntuple(Val(N)) do i + Base.@_inline_meta + Val(T) == tup[i] + end) +end + @inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} if T === Any @@ -469,20 +476,28 @@ end @static if VERSION < v"1.7.0" nT = T else - nT = if is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) - Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...,} + nT = if is_concrete_tuple(T) + Tuple{(ntuple(length(T.parameters)) do i + Base.@_inline_meta + sT = T.parameters[i] + if sT isa Core.TypeofVararg + Any + else + sT + end + end)...} else T end end - if Val(nT) ∈ seen + if staticInTup(Val(nT), seen) return MixedState end - seen = (Val(nT), seen...) + seen2 = (Val(nT), seen...) - fty = Merger{seen,typeof(world),justActive, UnionSret}(world) + fty = Merger{seen2,typeof(world),justActive, UnionSret}(world) ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...) 
@@ -521,7 +536,7 @@ end return res end -Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) +@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} ActReg = active_reg_inner(T, (), nothing) @@ -1177,6 +1192,30 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) end end +@inline function EnzymeCore.make_zero(x::Array{FT, N})::Array{FT, N} where {FT <: AbstractFloat, N} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero(x::Array{Complex{FT}, N})::Array{Complex{FT}, N} where {FT <: AbstractFloat, N} + return Base.zero(x) +end + +@inline function EnzymeCore.make_zero(::Type{Array{FT, N}}, seen::IdDict, prev::Array{FT, N}, ::Val{copy_if_inactive}=Val(false))::Array{FT, N} where {copy_if_inactive, FT<:AbstractFloat, N} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end +@inline function EnzymeCore.make_zero(::Type{Array{Complex{FT}, N}}, seen::IdDict, prev::Array{Complex{FT}, N}, ::Val{copy_if_inactive}=Val(false))::Array{Complex{FT}, N} where {copy_if_inactive, FT<:AbstractFloat, N} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end + @inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:AbstractFloat} return RT(0) end @@ -1205,11 +1244,12 @@ end end @inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Tuple} - return ((EnzymeCore.make_zero(a, seen, prev[i], Val(copy_if_inactive)) for (i, a) in enumerate(RT.parameters))...,) + return ntuple(length(prev)) do i + Base.@_inline_meta + EnzymeCore.make_zero(RT.parameters[i], 
seen, prev[i], Val(copy_if_inactive)) + end end - - @inline function EnzymeCore.make_zero(::Type{NamedTuple{A,RT}}, seen::IdDict, prev::NamedTuple{A,RT}, ::Val{copy_if_inactive}=Val(false))::NamedTuple{A,RT} where {copy_if_inactive, A,RT} return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) end diff --git a/test/abi.jl b/test/abi.jl index ef0db2fa22..7371af504e 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -408,3 +408,37 @@ end @test r[2][1] ≈ -400.0 @test r[2][2] ≈ 200.0 end + +abssum(x) = sum(abs2, x); + +@testset "Type inference" begin + x = ones(10) + @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) + @inferred autodiff(Enzyme.ReverseWithPrimal, abssum, Duplicated(x,x)) + @inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x,x)) + @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x,x)) + @inferred autodiff(Enzyme.Forward, abssum, Duplicated(x,x)) + @inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x,x)) + @inferred autodiff(Enzyme.Forward, abssum, DuplicatedNoNeed, Duplicated(x,x)) + + @inferred gradient(Reverse, abssum, x) + @inferred gradient!(Reverse, x, abssum, x) + + cx = ones(10) + @inferred autodiff(Enzyme.ReverseHolomorphic, sum, Duplicated(cx,cx)) + @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, sum, Duplicated(cx,cx)) + @inferred autodiff(Enzyme.Forward, sum, Duplicated(cx,cx)) + + @inferred Enzyme.make_zero(x) + @inferred Enzyme.make_zero(cx) + + tx = (1.0, 2.0, 3.0) + + @inferred Enzyme.Compiler.active_reg_inner(Tuple{Float64,Float64,Float64}, (), nothing, Val(true)) + @inferred Enzyme.make_zero(tx) + + @inferred gradient(Reverse, abssum, tx) + @inferred gradient(Forward, abssum, tx) + +end + From d1419b96003c9cb9d2a6e6fb9ca97c4c696fe26f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 May 2024 10:08:31 -0400 Subject: [PATCH 021/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/Project.toml b/Project.toml index 7340591dc4..b9353f4863 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ EnzymeSpecialFunctionsExt = "SpecialFunctions" [compat] CEnum = "0.4, 0.5" EnzymeCore = "0.7" -Enzyme_jll = "0.0.104" +Enzyme_jll = "0.0.105" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1" ObjectFile = "0.4" From 4bcae3322322d67bb54c0e3160ef41b6da990ef5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 May 2024 10:10:16 -0400 Subject: [PATCH 022/495] WIP opt out of types (#1413) * WIP opt out of types * Update src/compiler.jl Co-authored-by: Valentin Churavy * fix --------- Co-authored-by: Valentin Churavy --- src/compiler.jl | 116 ++++++++++++++++++++++--------------- src/rules/activityrules.jl | 53 +++++++++-------- test/runtests.jl | 11 ++++ 3 files changed, 107 insertions(+), 73 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dcfb9b4efc..05a300c2bd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4426,6 +4426,21 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function return wrapper_f, returnRoots, boxedArgs, loweredArgs end +using Random +# returns arg, return +function no_type_setting(@nospecialize(specTypes); world=nothing) + @static if VERSION >= v"1.7.0-" + # Even though the julia type here is ptr{int8}, the actual data can be something else + if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) + return (true, false) + end + if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_nosimd) + return (true, false) + end + end + return (false, false) +end + function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) @@ -4596,71 +4611,76 @@ function GPUCompiler.codegen(output::Symbol, 
job::CompilerJob{<:EnzymeTarget}; ctx = LLVM.context(f) - for arg in jlargs - if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam - continue - end - push!( - parameter_attributes(f, arg.codegen.i), - StringAttribute( - "enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))) - ), - ) - push!( - parameter_attributes(f, arg.codegen.i), - StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc))), - ) + push!(function_attributes(f), StringAttribute("enzyme_ta_norecur")) - byref = arg.cc + if !no_type_setting(mi.specTypes; world)[1] + for arg in jlargs + if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam + continue + end + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute( + "enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))) + ), + ) + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc))), + ) - rest = typetree(arg.typ, ctx, dl) + byref = arg.cc - if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF - # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader - # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} - # aka the next field after this in the bigger object isn't guaranteed to also be the same. 
- if allocatedinline(arg.typ) - shift!(rest, dl, 0, sizeof(arg.typ), 0) - end - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - else - # canonicalize wrt size - end - push!( - parameter_attributes(f, arg.codegen.i), - StringAttribute("enzyme_type", string(rest)), - ) - end + rest = typetree(arg.typ, ctx, dl) - if sret !== nothing - idx = 0 - if !in(0, parmsRemoved) - rest = typetree(sret, ctx, dl) + if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader + # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} + # aka the next field after this in the bigger object isn't guaranteed to also be the same. + if allocatedinline(arg.typ) + shift!(rest, dl, 0, sizeof(arg.typ), 0) + end + merge!(rest, TypeTree(API.DT_Pointer, ctx)) + only!(rest, -1) + else + # canonicalize wrt size + end push!( - parameter_attributes(f, idx + 1), + parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_type", string(rest)), ) - idx += 1 end - if returnRoots !== nothing - if !in(1, parmsRemoved) - rest = TypeTree(API.DT_Pointer, -1, ctx) + end + + if !no_type_setting(mi.specTypes; world)[2] + if sret !== nothing + idx = 0 + if !in(0, parmsRemoved) + rest = typetree(sret, ctx, dl) push!( parameter_attributes(f, idx + 1), StringAttribute("enzyme_type", string(rest)), ) + idx += 1 + end + if returnRoots !== nothing + if !in(1, parmsRemoved) + rest = TypeTree(API.DT_Pointer, -1, ctx) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_type", string(rest)), + ) + end end end - end - if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() - @assert !retRemoved - rest = typetree(llRT, ctx, dl) - push!(return_attributes(f), StringAttribute("enzyme_type", string(rest))) + if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() + @assert !retRemoved + rest = typetree(llRT, ctx, dl) + 
push!(return_attributes(f), StringAttribute("enzyme_type", string(rest))) + end end - push!(function_attributes(f), StringAttribute("enzyme_ta_norecur")) end custom = Dict{String,LLVM.API.LLVMLinkage}() diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 45489031b3..2b36a9740e 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -27,7 +27,6 @@ function julia_activity_rule(f::LLVM.Function) if mi.specTypes.parameters[end] === Vararg{Any} return end - world = enzyme_extract_world(f) if expectLen != length(parameters(f)) @@ -39,39 +38,43 @@ function julia_activity_rule(f::LLVM.Function) jlargs = classify_arguments(mi.specTypes, function_type(f), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) - for arg in jlargs - if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam - continue - end + if !Enzyme.Compiler.no_type_setting(mi.specTypes; world)[1] + for arg in jlargs + if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam + continue + end - op_idx = arg.codegen.i - - typ, _ = enzyme_extract_parm_type(f, arg.codegen.i) - @assert typ == arg.typ + op_idx = arg.codegen.i + + typ, _ = enzyme_extract_parm_type(f, arg.codegen.i) + @assert typ == arg.typ - if guaranteed_const_nongen(arg.typ, world) - push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_inactive")) + if guaranteed_const_nongen(arg.typ, world) + push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_inactive")) + end end end - if sret !== nothing - idx = 0 - if !in(0, parmsRemoved) - if guaranteed_const_nongen(RT, world) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + if !Enzyme.Compiler.no_type_setting(mi.specTypes; world)[2] + if sret !== nothing + idx = 0 + if !in(0, parmsRemoved) + if guaranteed_const_nongen(RT, world) + push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + end + idx+=1 end - idx+=1 - end - if returnRoots !== nothing - if !in(idx, 
parmsRemoved) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + if returnRoots !== nothing + if !in(idx, parmsRemoved) + push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + end end end - end - if llRT !== nothing && LLVM.return_type(function_type(f)) != LLVM.VoidType() - if guaranteed_const_nongen(RT, world) - push!(return_attributes(f), StringAttribute("enzyme_inactive")) + if llRT !== nothing && LLVM.return_type(function_type(f)) != LLVM.VoidType() + if guaranteed_const_nongen(RT, world) + push!(return_attributes(f), StringAttribute("enzyme_inactive")) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index a8506016ca..8de8311271 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1145,6 +1145,17 @@ end Enzyme.API.runtimeActivity!(false) end +function fillsum(x) + a = similar(rand(3, 3)) + fill!(a, x) + return sum(a) +end + +@testset "Fill sum" begin + res = autodiff(Forward, fillsum, Duplicated(2.0, 1.0))[1] + @test 9.0 ≈ res +end + mutable struct RTGData x From 1e84d872162107cfa2b3506c2b49f218411fd1b3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 May 2024 11:09:16 -0400 Subject: [PATCH 023/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b9353f4863..e279c5c46c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.1" +version = "0.12.2" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 1c184e327e4f53440777d841e6632b88e96885d8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 May 2024 15:30:00 -0400 Subject: [PATCH 024/495] fix add_one_in_place (#1418) --- src/compiler.jl | 7 ++--- test/runtests.jl | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl 
b/src/compiler.jl index 05a300c2bd..097ff04c5a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3554,7 +3554,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - cf = nested_codegen!(Mode, mod, add_one_in_place, Tuple{actualRetType}, world) + cf = nested_codegen!(Mode, mod, add_one_in_place, Tuple{Any}, world) push!(function_attributes(cf), EnumAttribute("alwaysinline", 0)) for shadowv in shadows c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) @@ -5322,9 +5322,8 @@ end end function add_one_in_place(x) - ty = typeof(x) - if ty <: Base.RefValue || ty == Base.RefValue{Float64} - x[] = recursive_add(x[], default_adjoint(eltype(ty))) + if x isa Base.RefValue + x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) else error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string(x)) end diff --git a/test/runtests.jl b/test/runtests.jl index 8de8311271..4ca7aa45a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2857,6 +2857,83 @@ end end end +const SEED = 42 +const N_SAMPLES = 500 +const N_COMPONENTS = 4 + +const rnd = Random.MersenneTwister(SEED) +const data = randn(rnd, N_SAMPLES) +const params0 = [rand(rnd, N_COMPONENTS); randn(rnd, N_COMPONENTS); 2rand(rnd, N_COMPONENTS)] + +# ========== Objective function ========== +normal_pdf(x::Real, mean::Real, var::Real) = + exp(-(x - mean)^2 / (2var)) / sqrt(2π * var) + +normal_pdf(x, mean, var) = + exp(-(x - mean)^2 / (2var)) / sqrt(2π * var) + +# original objective (doesn't work) +function mixture_loglikelihood1(params::AbstractVector{<:Real}, data::AbstractVector{<:Real})::Real + K = length(params) ÷ 3 + weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end] + mat = normal_pdf.(data, means', stds' .^2) # (N, K) + sum(mat .* weights', dims=2) .|> log |> sum +end + +# another form of original objective (doesn't work) +function mixture_loglikelihood2(params::AbstractVector{<:Real}, 
data::AbstractVector{<:Real})::Real + K = length(params) ÷ 3 + weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end] + mat = normal_pdf.(data, means', stds' .^2) # (N, K) + obj_true = sum( + sum( + weight * normal_pdf(x, mean, std^2) + for (weight, mean, std) in zip(weights, means, stds) + ) |> log + for x in data + ) +end + +# objective re-written by me +function mixture_loglikelihood3(params::AbstractVector{<:Real}, data::AbstractVector{<:Real})::Real + K = length(params) ÷ 3 + weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end] + mat = normal_pdf.(data, means', stds' .^2) # (N, K) + + obj = zero(eltype(mat)) + for x in data + obj_i = zero(eltype(mat)) + for (weight, mean, std) in zip(weights, means, stds) + obj_i += weight * normal_pdf(x, mean, std^2) + end + obj += log(obj_i) + end + return obj +end + +const objective1 = params -> mixture_loglikelihood1(params, data) +const objective2 = params -> mixture_loglikelihood2(params, data) +const objective3 = params -> mixture_loglikelihood3(params, data) + +@testset "Type unsstable return" begin + expected = [289.7308495620467, + 199.27559524985728, + 236.6894577756876, + 292.0612340227955, + -9.429799389881452, + 26.722295646439047, + -1.9180355546752244, + 37.98749089573396, + -24.095620148778277, + -13.935687326484112, + -38.00044665702692, + 12.87712891527131] + @test expected ≈ Enzyme.gradient(Reverse, objective1, params0) + # objective2 fails from runtime activity requirements + # @test expected ≈ Enzyme.gradient(Reverse, objective2, params0) + @test expected ≈ Enzyme.gradient(Reverse, objective3, params0) +end + struct HarmonicAngle k::Float64 t0::Float64 From ff878b53e723431de3d3f12f58472e7e0d6fbdce Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 8 May 2024 00:43:53 +0200 Subject: [PATCH 025/495] add import_frule (reprised) (#1333) * Add import frule functionality * Add wip rrule importer * move everything to extension; add tests frule * remove 
import_rrule * runtests * esc(fn) fixes tests * added failing batchduplicated test * remove test import * address review comments * add dollar in macro * cleanup * Fixup and cleanup --------- Co-authored-by: William S. Moses Co-authored-by: Billy Moses --- Project.toml | 8 ++ ext/EnzymeChainRulesCoreExt.jl | 107 +++++++++++++++++++++ src/Enzyme.jl | 5 + test/Project.toml | 2 + test/ext/chainrulescore.jl | 70 ++++++++++++++ test/{packages => ext}/specialfunctions.jl | 0 test/runtests.jl | 29 ++++-- 7 files changed, 211 insertions(+), 10 deletions(-) create mode 100644 ext/EnzymeChainRulesCoreExt.jl create mode 100644 test/ext/chainrulescore.jl rename test/{packages => ext}/specialfunctions.jl (100%) diff --git a/Project.toml b/Project.toml index e279c5c46c..0f5073d450 100644 --- a/Project.toml +++ b/Project.toml @@ -17,17 +17,25 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeChainRulesCoreExt = "ChainRulesCore" [compat] CEnum = "0.4, 0.5" +ChainRulesCore = "1" EnzymeCore = "0.7" Enzyme_jll = "0.0.105" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1" ObjectFile = "0.4" Preferences = "1.4" +SpecialFunctions = "1, 2" julia = "1.6" + +[extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" \ No newline at end of file diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl new file mode 100644 index 0000000000..2c8d180a57 --- /dev/null +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -0,0 +1,107 @@ +module EnzymeChainRulesCoreExt + +using ChainRulesCore +using EnzymeCore +using Enzyme + + +""" + import_frule(::fn, tys...) + +Automatically import a `ChainRulesCore.frule`` as a custom forward mode `EnzymeRule`. 
When called in batch mode, this +will end up calling the primal multiple times, which may result in incorrect behavior if the function mutates, +and slow code, always. Importing the rule from `ChainRules` is also likely to be slower than writing your own rule, +and may also be slower than not having a rule at all. + +Use with caution. + +```jldoctest +Enzyme.@import_frule(typeof(Base.sort), Any); + +x=[1.0, 2.0, 0.0]; dx=[0.1, 0.2, 0.3]; ddx = [0.01, 0.02, 0.03]; + +Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,ddx))) +Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,ddx))) +Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,))) +Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,))) + +# output + +(var"1" = [0.0, 1.0, 2.0], var"2" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02])) +(var"1" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]),) +(var"1" = [0.3, 0.1, 0.2],) +(var"1" = [0.0, 1.0, 2.0], var"2" = [0.3, 0.1, 0.2]) + +``` +""" +function Enzyme._import_frule(fn, tys...) + vals = [] + exprs = [] + primals = [] + tangents = [] + tangentsi = [] + anns = [] + for (i, ty) in enumerate(tys) + val = Symbol("arg_$i") + TA = Symbol("AN_$i") + e = :($val::$TA) + push!(anns, :($TA <: Annotation{<:$ty})) + push!(vals, val) + push!(exprs, e) + push!(primals, :($val.val)) + push!(tangents, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval)) + push!(tangentsi, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval[i])) + end + + quote + function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} + batchsize = same_or_one(1, $(vals...)) + if batchsize == 1 + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval + cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...) 
+ if RetAnnotation <: Const + return nothing + elseif RetAnnotation <: Duplicated + return Duplicated(cres[1], cres[2]) + elseif RetAnnotation <: DuplicatedNoNeed + return cres[2]::eltype(RetAnnotation) + else + @assert false + end + else + if RetAnnotation <: Const + ntuple(Val(batchsize)) do i + Base.@_inline_meta + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) + end + return nothing + elseif RetAnnotation <: BatchDuplicated + cres1 = begin + i = 1 + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) + end + batches = ntuple(Val(batchsize-1)) do j + Base.@_inline_meta + i = j+1 + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2] + end + return BatchDuplicated(cres1[1], (cres1[2], batches...)) + elseif RetAnnotation <: BatchDuplicatedNoNeed + ntuple(Val(batchsize)) do i + Base.@_inline_meta + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2] + end + else + @assert false + end + end + end + end # quote +end + + +end # module \ No newline at end of file diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 0fa3f6bb90..748ce04c04 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1197,5 +1197,10 @@ end mapreduce(LinearAlgebra.adjoint, vcat, rows) end +function _import_frule end # defined in EnzymeChainRulesCoreExt extension + +macro import_frule(args...) + return _import_frule(args...) 
+end end # module diff --git a/test/Project.toml b/test/Project.toml index 28cd57f049..bf44952c27 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef" diff --git a/test/ext/chainrulescore.jl b/test/ext/chainrulescore.jl new file mode 100644 index 0000000000..217176a657 --- /dev/null +++ b/test/ext/chainrulescore.jl @@ -0,0 +1,70 @@ +using Enzyme +using Test +using ChainRules +using ChainRulesCore +using LinearAlgebra +using EnzymeTestUtils + +fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] + +@testset "import_frule" begin + f1(x) = 2*x + ChainRulesCore.@scalar_rule f1(x) (5*one(x),) + Enzyme.@import_frule typeof(f1) Any + @test fdiff(f1, 1f0) === 5f0 + @test fdiff(f1, 1.0) === 5.0 + + # specific signature + f2(x) = 2*x + ChainRulesCore.@scalar_rule f2(x) (5*one(x),) + Enzyme.@import_frule typeof(f2) Float32 + @test fdiff(f2, 1f0) === 5f0 + @test fdiff(f2, 1.0) === 2.0 + + # two arguments + f3(x, y) = 2*x + y + ChainRulesCore.@scalar_rule f3(x, y) (5*one(x), y) + Enzyme.@import_frule typeof(f3) Any Any + @test fdiff(x -> f3(x, 1.0), 2.) === 5.0 + @test fdiff(y -> f3(1.0, y), 2.) 
=== 2.0 + + @testset "batch duplicated" begin + x = [1.0, 2.0, 0.0] + Enzyme.@import_frule typeof(Base.sort) Any + + test_forward(Base.sort, Duplicated, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, Duplicated, (x, DuplicatedNoNeed)) + test_forward(Base.sort, DuplicatedNoNeed, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, DuplicatedNoNeed, (x, DuplicatedNoNeed)) + test_forward(Base.sort, Const, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, Const, (x, DuplicatedNoNeed)) + + test_forward(Base.sort, Const, (x, Const)) + + # ChainRules does not support this case (returning notangent) + # test_forward(Base.sort, Duplicated, (x, Const)) + # test_forward(Base.sort, DuplicatedNoNeed, (x, Const)) + + test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicatedNoNeed)) + test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicatedNoNeed)) + test_forward(Base.sort, Const, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, Const, (x, BatchDuplicatedNoNeed)) + + # ChainRules does not support this case (returning notangent) + # test_forward(Base.sort, BatchDuplicated, (x, Const)) + # test_forward(Base.sort, BatchDuplicatedNoNeed, (x, Const)) + end +end + + + + + diff --git a/test/packages/specialfunctions.jl b/test/ext/specialfunctions.jl similarity index 100% rename from test/packages/specialfunctions.jl rename to test/ext/specialfunctions.jl diff --git a/test/runtests.jl b/test/runtests.jl index 4ca7aa45a8..bf7bcfee5d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -100,15 +100,6 @@ end include("blas.jl") end -@static if VERSION ≥ v"1.9-" - using SpecialFunctions - @testset "SpecialFunctions ext" begin - lgabsg(x) = 
SpecialFunctions.logabsgamma(x)[1] - test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) - test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) - end -end - f0(x) = 1.0 + x function vrec(start, x) if start > length(x) @@ -1218,7 +1209,7 @@ end ## https://github.com/JuliaDiff/ChainRules.jl/tree/master/test/rulesets if !Sys.iswindows() - include("packages/specialfunctions.jl") + include("ext/specialfunctions.jl") end @testset "Threads" begin @@ -3032,5 +3023,23 @@ end @test res[2][5] ≈ 0 @test res[2][6] ≈ 6.0 end + +# TEST EXTENSIONS +@static if VERSION ≥ v"1.9-" + using SpecialFunctions + @testset "SpecialFunctions ext" begin + lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] + test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) + test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) + end + + using ChainRulesCore + @testset "ChainRulesCore ext" begin + include("ext/chainrulescore.jl") + end +end + + + end From 43193d6251b56ea4b3710e2bba2b95f0d1db9f13 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 May 2024 20:47:03 -0400 Subject: [PATCH 026/495] Handle constant gep (#1419) --- src/compiler/optimize.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 5be2d712a8..4dcf37a145 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -422,6 +422,14 @@ function nodecayed_phis!(mod::LLVM.Module) return v2, offset, skipload end + if isa(v, LLVM.ConstantExpr) && opcode(v) == LLVM.API.LLVMGetElementPtr && !hasload + v2, offset, skipload = getparent(operands(v)[1], offset, hasload) + offset = nuwadd!(b, offset, API.EnzymeComputeByteOffsetOfGEP(b, v, offty)) + v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + undeforpoison = isa(v, LLVM.UndefValue) @static if LLVM.version() >= v"12" undeforpoison |= isa(v, LLVM.PoisonValue) From 
b7c96de0547d1e9ea339be066e370c2930eb01ac Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 7 May 2024 21:54:53 -0400 Subject: [PATCH 027/495] Mark inactive functions as having no escaping allocation --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 097ff04c5a..28486503e3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4793,11 +4793,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end if EnzymeRules.is_inactive_from_sig(mi.specTypes; world, method_table, caller) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")]) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")]) continue end if EnzymeRules.is_inactive_noinl_from_sig(mi.specTypes; world, method_table, caller) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")], false, false) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")], false, false) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) From 04aa71b0801b21545baccb9adc14f114eedd0d2a Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 8 May 2024 15:20:46 -0400 Subject: [PATCH 028/495] Correct dead arg elimination (#1421) --- src/compiler/optimize.jl | 21 +++++++++++++++++---- test/runtests.jl | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 4dcf37a145..7ee45963cb 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -837,6 +837,15 @@ function propagate_returned!(mod::LLVM.Module) if !prevent && (linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) && 
any(kind(attr) == kind(EnumAttribute("nocapture")) for attr in collect(parameter_attributes(fn, i))) val = nothing illegalUse = false + torem = LLVM.Instruction[] + argeltype = if LLVM.version().major >= 12 + # TODO try to get sret element type if possible + # note currently opaque pointers has this break [and we need to doa check if opaque + # and if so get inner piece] + eltype(value_type(arg)) + else + eltype(value_type(arg)) + end for u in LLVM.uses(fn) un = LLVM.user(u) if !isa(un, LLVM.CallInst) @@ -859,8 +868,8 @@ function propagate_returned!(mod::LLVM.Module) illegalUse = true break end + eltype = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i])) seenfn = false - torem = LLVM.Instruction[] todo = LLVM.Instruction[] for u2 in LLVM.uses(ops[i]) un2 = LLVM.user(u2) @@ -905,14 +914,19 @@ function propagate_returned!(mod::LLVM.Module) push!(torem, un2) end if illegalUse - continue + break end + end + if !illegalUse for c in reverse(torem) unsafe_delete!(LLVM.parent(c), c) end B = IRBuilder() position!(B, first(instructions(first(blocks(fn))))) - al = alloca!(B, LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i]))) + al = alloca!(B, argeltype) + if value_type(al) != value_type(arg) + al = addrspacecast!(B, al, value_type(arg)) + end LLVM.replace_uses!(arg, al) end end @@ -1547,7 +1561,6 @@ end gvn!(pm) # Exxtra run!(pm, mod) end - removeDeadArgs!(mod) detect_writeonly!(mod) nodecayed_phis!(mod) diff --git a/test/runtests.jl b/test/runtests.jl index bf7bcfee5d..1331eb0cc5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -535,6 +535,27 @@ end @test autodiff(Reverse, f10, Active, Active(2.0))[1][1] == sqrt(5) end +function deadarg_pow(z::T, i) where {T<:Real} + zabs = abs(z) + if sign(z) < zero(T) + return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im) + end + return zabs^i + zero(T)im +end + +function deadargtest(n) + wp = 1 + deadarg_pow(-n, 0.5) + + deadarg_pow(-n, 0.5) + + return real(wp) +end + +@testset "Dead arg elim" begin + res = 
autodiff(Enzyme.ReverseWithPrimal, deadargtest, Active, Active(0.25)) + @test res[2] ≈ 1.0 +end + @testset "Taylor series tests" begin # Taylor series for `-log(1-x)` From 732af6f729ae65dfb64d9f5a7760d7a272de9ece Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 9 May 2024 05:04:03 -0400 Subject: [PATCH 029/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 0f5073d450..d52945b335 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ EnzymeChainRulesCoreExt = "ChainRulesCore" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.105" +Enzyme_jll = "0.0.106" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1" ObjectFile = "0.4" @@ -38,4 +38,4 @@ julia = "1.6" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" \ No newline at end of file +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" From f7b4c4ed7699ecc312f3b377634c789cc5803a86 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 9 May 2024 12:42:43 -0700 Subject: [PATCH 030/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d52945b335..2b11b8a6d3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.2" +version = "0.12.3" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 9c5e95e40f7b421d7464698854d04468f9a9f588 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Fri, 10 May 2024 20:33:55 +0200 Subject: [PATCH 031/495] Adjust triangular solve rules to new EnzymeTestUtils capabilities (#1407) --- test/internal_rules.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git 
a/test/internal_rules.jl b/test/internal_rules.jl index b076a51b3e..965d8d4b55 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -423,15 +423,12 @@ end A = T(M) @testset "test through constructor" begin _A = T(A) - function f!(Y, A, B, ::T) where T - ldiv!(Y, T(A), B) - return nothing - end + f!(Y, A, B, ::T) where {T} = ldiv!(Y, T(A), B) for TY in (Const, Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), TB in (Const, Duplicated, BatchDuplicated) are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const)) end end @testset "test through `Adjoint` wrapper (regression test for #1306)" begin From 32dd788c0823cd4ccad9af9e00c73cb278ae56fa Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 10 May 2024 12:37:11 -0700 Subject: [PATCH 032/495] CompatHelper: bump compat for LLVM to 7, (keep existing compat) (#1420) --- Project.toml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 2b11b8a6d3..5dc0112338 100644 --- a/Project.toml +++ b/Project.toml @@ -16,26 +16,26 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - -[extensions] -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeChainRulesCoreExt = "ChainRulesCore" - [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" Enzyme_jll = "0.0.106" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" -LLVM = "6.1" +LLVM = "6.1, 7" ObjectFile = "0.4" Preferences = "1.4" SpecialFunctions = "1, 2" julia = "1.6" +[extensions] +EnzymeChainRulesCoreExt = "ChainRulesCore" 
+EnzymeSpecialFunctionsExt = "SpecialFunctions" + [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" \ No newline at end of file From 5d4f9e02f38195217c023d00ebe813d2895984ed Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 10 May 2024 13:43:34 -0700 Subject: [PATCH 033/495] No escaping allocation (#1422) --- src/compiler.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 28486503e3..21895e3b4b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2653,6 +2653,7 @@ end function annotate!(mod, mode) inactive = LLVM.StringAttribute("enzyme_inactive", "") active = LLVM.StringAttribute("enzyme_active", "") + no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") fns = functions(mod) for f in fns @@ -2663,6 +2664,7 @@ function annotate!(mod, mode) if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), inactive) + push!(function_attributes(fn), no_escaping_alloc) for u in LLVM.uses(fn) c = LLVM.user(u) if !isa(c, LLVM.CallInst) @@ -2679,6 +2681,7 @@ function annotate!(mod, mode) continue end LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) end end end @@ -2771,6 +2774,15 @@ function annotate!(mod, mode) end end + for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw") + if haskey(fns, fname) + fn = fns[fname] + push!(function_attributes(fn), no_escaping_alloc) + end + end + + + for fname in ("julia.pointer_from_objref",) if haskey(fns, fname) fn = fns[fname] @@ -2788,6 +2800,7 @@ function annotate!(mod, mode) if 
haskey(fns, boxfn) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) + push!(function_attributes(fn), no_escaping_alloc) if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) end @@ -2813,6 +2826,7 @@ function annotate!(mod, mode) continue end LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) + LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, no_escaping_alloc) if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) end From 1aa60874ea617ad5e59a7976c5d82ad310523f8a Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 10 May 2024 14:26:40 -0700 Subject: [PATCH 034/495] Fewerprints (#1423) * Now with fewer prints * Even fewer prints --- src/compiler/optimize.jl | 86 +++++++++++++++++++++----------------- src/compiler/validation.jl | 34 +++++++++++---- src/rules/activityrules.jl | 19 ++++++--- 3 files changed, 87 insertions(+), 52 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 7ee45963cb..f5ae2f823f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -667,12 +667,15 @@ function fix_decayaddr!(mod::LLVM.Module) end end if !sret - println(string(f)) - @show inst, st, fop - flush(stdout) + msg = sprint() do io + println(io, "Enzyme Internal Error: did not have sret when expected") + println(io, "f=", string(f)) + println(io, "inst=", string(inst)) + println(io, "st=", string(st)) + println(io, "fop=", string(fop)) + end + throw(AssertionError(msg)) end - - @assert sret elt = eltype(value_type(inst)) if temp === nothing @@ -1208,9 +1211,12 @@ function validate_return_roots!(mod) if 
length(enzyme_srets) == 1 && LLVM.return_type(LLVM.function_type(f)) == VT && length(enzyme_srets_v) == 0 # Upgrading to sret requires writeonly if !any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in collect(parameter_attributes(f, 1))) - @show f - @show collect(parameter_attributes(f, 1)) - @assert false + msg = sprint() do io::IO + println(io, "Enzyme internal error (not writeonly sret)") + println(io, string(f)) + println(io, "collect(parameter_attributes(f, 1))=", collect(parameter_attributes(f, 1))) + end + throw(AssertionError(msg)) end alty = nothing @@ -1220,7 +1226,14 @@ function validate_return_roots!(mod) @assert LLVM.called_operand(u) == f alop = operands(u)[1] if !isa(alop, LLVM.AllocaInst) - @show alop, u, f + msg = sprint() do io::IO + println(io, "Enzyme internal error (!isa(alop, LLVM.AllocaInst))") + println(io, "alop=", alop) + println(io, "u=", u) + println(io, "f=", string(f)) + end + throw(AssertionError(msg)) + end @assert isa(alop, LLVM.AllocaInst) nty = API.EnzymeAllocaType(alop) @@ -1275,9 +1288,16 @@ function validate_return_roots!(mod) enzyme_srets = enzyme_srets2 if length(enzyme_srets) != 0 - @show f - @show enzyme_srets, enzyme_srets_v, srets, rroots, rroots_v - @assert false + msg = sprint() do io::IO + println(io, "Enzyme internal error (length(enzyme_srets) != 0)") + println(io, "f=", string(f)) + println(io, "enzyme_srets=", enzyme_srets) + println(io, "enzyme_srets_v=", enzyme_srets_v) + println(io, "srets=", srets) + println(io, "rroots=", rroots) + println(io, "rroots_v=", rroots_v) + end + throw(AssertionError(msg)) end end end @@ -1303,39 +1323,33 @@ function checkNoAssumeFalse(mod, shouldshow=false) continue end intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(inst)) - if shouldshow - @show intr, inst - end if intr != LLVM.Intrinsic("llvm.assume").id continue end - if shouldshow - @show inst - end op = operands(inst)[1] - if shouldshow - @show op - end if isa(op, LLVM.ConstantInt) op2 = convert(Bool, op) - 
if shouldshow - @show op2 - end if !op2 - println(string(mod)) - println(string(f)) - println(string(bb)) - flush(stdout) - @assert false + msg = sprint() do io + println(io, "Enzyme Internal Error: non-constant assume condition") + println(io, "mod=", string(mod)) + println(io, "f=", string(f)) + println(io, "bb=", string(bb)) + println(io, "op2=", string(op2)) + end + throw(AssertionError(msg)) end end if isa(op, LLVM.ICmpInst) if predicate_int(op) == LLVM.API.LLVMIntNE && operands(op)[1] == operands(op)[2] - println(string(mod)) - println(string(f)) - println(string(bb)) - flush(stdout) - @assert false + msg = sprint() do io + println(io, "Enzyme Internal Error: non-icmp assume condition") + println(io, "mod=", string(mod)) + println(io, "f=", string(f)) + println(io, "bb=", string(bb)) + println(io, "op=", string(op)) + end + throw(AssertionError(msg)) end end end @@ -1716,10 +1730,6 @@ function post_optimze!(mod, tm, machine=true) if LLVM.API.LLVMVerifyModule(mod, LLVM.API.LLVMReturnStatusAction, out_error) != 0 throw(LLVM.LLVMException("broken gc calling conv fix\n"*string(unsafe_string(out_error[]))*"\n"*string(mod))) end - # println(string(mod)) - # @safe_show "pre_post", mod - # flush(stdout) - # flush(stderr) LLVM.ModulePassManager() do pm addTargetPasses!(pm, tm, LLVM.triple(mod)) addOptimizationPasses!(pm) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index d5ecc3c424..75edca9d9b 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -576,9 +576,16 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if legal && isa(flib, Core.MethodInstance) if !Base.isvarargtype(flib.specTypes.parameters[end]) if length(tys) != length(flib.specTypes.parameters) - @show tys, flib, inst, offset, start + msg = sprint() do io::IO + println(io, "Enzyme internal error (length(tys) != length(flib.specTypes.parameters))") + println(io, "tys=", tys) + println(io, "flib=", flib) + println(io, "inst=", inst) + 
println(io, "offset=", offset) + println(io, "start=", start) + end + throw(AssertionError(msg)) end - @assert length(tys) == length(flib.specTypes.parameters) end tys = flib.specTypes.parameters end @@ -726,9 +733,15 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width continue end - println(string(enzymefn)) - @show "BAD", acur, aoff, prev - @assert false + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[1])") + println(io, string(enzymefn)) + println(io, "BAD") + println(io, "acur=", acur) + println(io, "aoff=", aoff) + println(io, "prev=", prev) + end + throw(AssertionError(msg)) end continue end @@ -744,9 +757,12 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end - println(string(enzymefn)) - - @show cur, off - @assert false + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") + println(io, string(enzymefn)) + println(io, "cur=", cur) + println(io, "off=", off) + end + throw(AssertionError(msg)) end end diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 2b36a9740e..9e32023957 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -29,12 +29,21 @@ function julia_activity_rule(f::LLVM.Function) end world = enzyme_extract_world(f) - if expectLen != length(parameters(f)) - println(string(f)) - @show expectLen, swiftself, sret, returnRoots, mi.specTypes.parameters, retRemoved, parmsRemoved - end # TODO fix the attributor inlining such that this can assert always true - @assert expectLen == length(parameters(f)) + if expectLen != length(parameters(f)) + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (expectLen != length(parameters(f)))") + println(io, string(f)) + println(io, "expectLen=", string(expectLen)) + println(io, "swiftself=", string(swiftself)) + println(io, "sret=", string(sret)) + println(io, "returnRoots=", string(returnRoots)) + 
println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters)) + println(io, "retRemoved=", string(retRemoved)) + println(io, "parmsRemoved=", string(parmsRemoved)) + end + throw(AssertionError(msg)) + end jlargs = classify_arguments(mi.specTypes, function_type(f), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) From 4663bba3c00bb77b13be114bcbc81e9b3a017256 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 10 May 2024 16:59:05 -0700 Subject: [PATCH 035/495] Update Project.toml --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 5dc0112338..381de17f1b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.3" +version = "0.12.4" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.106" +Enzyme_jll = "0.0.107" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" @@ -38,4 +38,4 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" \ No newline at end of file +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" From 96681470080e8ad5524d42fad0e1a5e61e09584e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 10 May 2024 16:59:18 -0700 Subject: [PATCH 036/495] Fix emiterror (#1425) --- src/compiler.jl | 6 +++--- src/gradientutils.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 21895e3b4b..8491f5d05f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1649,7 +1649,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end 
throw(exc) elseif errtype == API.ET_NoShadow - data = GradientUtils(API.EnzymeGradientUtilsRef(data)) + gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) msgN = sprint() do io::IO print(io, "Enzyme could not find shadow for value\n") @@ -1661,7 +1661,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end if !isa(val, LLVM.Argument) print(io, "\n Inverted pointers: \n") - ip = API.EnzymeGradientUtilsInvertedPointersToString(data) + ip = API.EnzymeGradientUtilsInvertedPointersToString(gutils) sval = Base.unsafe_string(ip) write(io, sval) API.EnzymeStringFree(ip) @@ -1673,7 +1673,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err println(io) end end - emit_error(B, nothing, msgN) + emit_error(IRBuilder(B), nothing, msgN) return LLVM.null(get_shadow_type(gutils, value_type(val))).ref elseif errtype == API.ET_IllegalTypeAnalysis data = API.EnzymeTypeAnalyzerRef(data) diff --git a/src/gradientutils.jl b/src/gradientutils.jl index 67618e3a45..cc64726f8e 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -14,7 +14,7 @@ end get_width(gutils::GradientUtils) = API.EnzymeGradientUtilsGetWidth(gutils) get_mode(gutils::GradientUtils) = API.EnzymeGradientUtilsGetMode(gutils) -function get_shadow_type(gutils::GradientUtils, T::LLVM.Type) +function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) w = get_width(gutils) if w == 1 return T From 9ca0d9d19d3c53e628176930a1c333fa5c5b71b6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 11 May 2024 00:25:56 -0700 Subject: [PATCH 037/495] fix deferred (#1426) * fix deferred * fixup * no escaping --- src/Enzyme.jl | 38 +++++++++++++++++++++++++++++++++----- src/compiler.jl | 38 ++++++++++++++++++++++++++++++++++++-- src/compiler/validation.jl | 8 ++++++++ test/runtests.jl | 34 +++++++++++++++++++++++++++++++++- 4 files changed, 110 insertions(+), 8 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 748ce04c04..981bbadbab 
100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -795,7 +795,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(1, args...) @@ -819,10 +819,38 @@ result, ∂v, ∂A primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) - aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr) - adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType}(adjoint_ptr) + primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + + RT = 
if A2 <: Duplicated && width != 1 + if A2 isa UnionAll + BatchDuplicated{T, width} where T + else + BatchDuplicated{eltype(A2), width} + end + elseif A2 <: DuplicatedNoNeed && width != 1 + if A2 isa UnionAll + BatchDuplicatedNoNeed{T, width} where T + else + BatchDuplicatedNoNeed{eltype(A2), width} + end + else + A2 + end + + rt = if RT isa UnionAll + @static if VERSION < v"1.8-" + throw(MethodError(autodiff_deferred_thunk, (mode, tt, fa, a2, args...))) + else + RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})} + end + else + @assert RT isa DataType + RT + end + + aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, rt, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr) + adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, rt, TT, Val{width}, TapeType}(adjoint_ptr) aug_thunk, adj_thunk end diff --git a/src/compiler.jl b/src/compiler.jl index 8491f5d05f..74d6013fef 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2361,6 +2361,7 @@ function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize, wrapper_f = LLVM.Function(mod, "zeroType", LLVM.FunctionType(LLVM.VoidType(), [value_type(obj), T_int8, value_type(Size)])) push!(function_attributes(wrapper_f), StringAttribute("enzyme_math", "enzyme_zerotype")) push!(function_attributes(wrapper_f), StringAttribute("enzyme_inactive")) + push!(function_attributes(wrapper_f), StringAttribute("enzyme_no_escaping_allocation")) push!(function_attributes(wrapper_f), EnumAttribute("alwaysinline", 0)) push!(function_attributes(wrapper_f), EnumAttribute("nofree", 0)) push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0)) @@ -2774,7 +2775,22 @@ function annotate!(mod, mode) end end - for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw") + for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw", "julia.pointer_from_objref", + "ijl_array_grow_end", 
"jl_array_grow_end", "ijl_array_del_end", "jl_array_del_end", + "ijl_array_grow_beg", "jl_array_grow_beg", "ijl_array_del_beg", "jl_array_del_beg", + "ijl_array_grow_at", "jl_array_grow_at", + "ijl_array_del_at", "jl_array_del_at", + "ijl_pop_handler", "jl_pop_handler", + "ijl_push_handler", "jl_push_handler", + "ijl_module_name", "jl_module_name", + "ijl_restore_excstack", "jl_restore_excstack", + "julia.except_enter", + "ijl_get_nth_field_checked", "jl_get_nth_field_checked", + "jl_egal__unboxed", + "ijl_reshape_array", "jl_reshape_array", + "ijl_eqtable_get", "jl_eqtable_get", + "jl_gc_run_pending_finalizers", + ) if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), no_escaping_alloc) @@ -2826,7 +2842,7 @@ function annotate!(mod, mode) continue end LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) end @@ -4324,6 +4340,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if kind(prev) == kind(StringAttribute("enzyme_inactive")) push!(attributes, prev) end + if kind(prev) == kind(StringAttribute("enzyme_no_escaping_allocation")) + push!(attributes, prev) + end end if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 @@ -4794,6 +4813,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; EnumAttribute("speculatable", 0), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), + 
StringAttribute("enzyme_no_escaping_allocation") ]) continue end @@ -4815,6 +4835,20 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation")) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive")) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree")) + end + end + end + continue + end + if func === typeof(Base.match) + handleCustom(llvmfn, "base_match", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")], false, false) + for bb in blocks(llvmfn) + for inst in instructions(bb) + if isa(inst, LLVM.CallInst) + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation")) LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive")) LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree")) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 75edca9d9b..03e7a4458e 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -478,12 +478,16 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") LLVM.API.LLVMAddCallSiteAttribute(inst, 
reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) + no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) end if funclib == Base.tuple && length(operands(inst)) == 4+1+1 && Base.isconcretetype(GT) && Enzyme.Compiler.guaranteed_const_nongen(GT, world) inactive = LLVM.StringAttribute("enzyme_inactive", "") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) + no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) end end end @@ -515,6 +519,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) + no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) end end end @@ -596,6 +602,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, 
LLVM.API.LLVMAttributeFunctionIndex), nofree) + no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) end end end diff --git a/test/runtests.jl b/test/runtests.jl index 1331eb0cc5..7164667bac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -321,11 +321,16 @@ end Const{typeof(dot)}, Active, Duplicated{typeof(thunk_A)} ) @test Tuple{Float64,Float64} === TapeType + Ret = if VERSION < v"1.8-" + Active{Float64} + else + Active + end fwd, rev = Enzyme.autodiff_deferred_thunk( ReverseSplitWithPrimal, TapeType, Const{typeof(dot)}, - Active{Float64}, + Ret, Duplicated{typeof(thunk_A)} ) tape, primal, _ = fwd(Const(dot), dup) @@ -335,6 +340,33 @@ end @test all(dA .== [6.0, 10.0]) @test all(dA .== def_dA) @test all(dA .== thunk_dA) + + @static if VERSION < v"1.8-" + else + function kernel(len, A) + for i in 1:len + A[i] *= A[i] + end + end + + A = Array{Float64}(undef, 64) + dA = Array{Float64}(undef, 64) + + A .= (1:1:64) + dA .= 1 + + function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT} + TapeType = Enzyme.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) + forward, reverse = Enzyme.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) 
+ forward(Const(f), Const(ctx), args...)[1] + return nothing + end + + ModifiedBetween = Val((false, false, true)) + + aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA)) + end + end @testset "Simple Complex tests" begin From 416dc86c8dd93d40c15bd1ac3f58f3ff0b24c97e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 11 May 2024 16:46:20 +0200 Subject: [PATCH 038/495] Update Format.yml (#1428) --- .github/workflows/Format.yml | 58 +++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml index 682c2744dc..88098a453e 100644 --- a/.github/workflows/Format.yml +++ b/.github/workflows/Format.yml @@ -1,30 +1,40 @@ -name: Format suggestions - on: + push: + branches: + - master + tags: '*' pull_request: - -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + types: + - opened + - reopened + - synchronize + - ready_for_review jobs: format: - permissions: - contents: read - pull-requests: write - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 + timeout-minutes: 30 steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 - with: - version: 1 - - run: | - julia -e 'using Pkg; Pkg.add("JuliaFormatter")' - julia -e 'using JuliaFormatter; format("."; verbose=true)' - - uses: reviewdog/action-suggester@v1 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - tool_name: JuliaFormatter - fail_on_error: true + - uses: actions/checkout@v4.1.5 + + - uses: dorny/paths-filter@v3.0.2 + id: filter + with: + filters: | + julia_file_change: + - added|modified: '**.jl' + + - uses: julia-actions/setup-julia@latest + if: steps.filter.outputs.julia_file_change == 'true' + with: + version: 1.9 + + - name: Apply JuliaFormatter + if: steps.filter.outputs.julia_file_change == 'true' + run: 
| + julia --color=yes dev/flux_format.jl --verbose . + + - name: Check formatting diff + if: steps.filter.outputs.julia_file_change == 'true' + run: | + git diff --color=always --exit-code From ccdfd8c6e85f6a09398688d57ae0e7ef1e0ca3ff Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 11 May 2024 14:39:49 -0700 Subject: [PATCH 039/495] Getfield with reference (#1430) * fix deferred * Getfield with reference --- src/rules/typeunstablerules.jl | 33 ++++++++++++++++-------- test/runtests.jl | 47 ++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 149ed46893..91b1ee837e 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -249,11 +249,12 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -getfield_idx(v, idx) = ccall(:jl_get_nth_field_checked, Any, (Any, UInt), v, idx) -setfield_idx(v, idx, rhs) = ccall(:jl_set_nth_field, Cvoid, (Any, UInt, Any), v, idx, rhs) - function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} - res = getfield(dptr, symname) + res = if dptr isa Base.RefValue + Base.getfield(dptr[], symname) + else + Base.getfield(dptr, symname) + end RT = Core.Typeof(res) if active_reg(RT) if length(dptrs) == 0 @@ -271,7 +272,11 @@ function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs end function idx_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) 
where {T, symname, isconst} - res = getfield_idx(dptr, symname) + res = if dptr isa Base.RefValue + Base.getfield(dptr[], symname+1) + else + Base.getfield(dptr, symname+1) + end RT = Core.Typeof(res) if active_reg(RT) if length(dptrs) == 0 @@ -289,7 +294,11 @@ function idx_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptr end function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} - cur = getfield(dptr, symname) + cur = if dptr isa Base.RefValue + getfield(dptr[], symname) + else + getfield(dptr, symname) + end RT = Core.Typeof(cur) if active_reg(RT) && !isconst @@ -305,16 +314,20 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, return nothing end function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} - cur = getfield_idx(dptr, symname) + cur = if dptr isa Base.RefValue + Base.getfield(dptr[], symname+1) + else + Base.getfield(dptr, symname+1) + end RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield_idx(dptr, symname, recursive_add(cur, dret[])) + setfield!(dptr, symname+1, recursive_add(cur, dret[])) else - setfield_idx(dptr, symname, recursive_add(cur, dret[1][])) + setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) for i in 1:length(dptrs) - setfield_idx(dptrs[i], symname, recursive_add(cur, dret[1+i][])) + setfield!(dptrs[i], symname+1, recursive_add(cur, dret[1+i][])) end end end diff --git a/test/runtests.jl b/test/runtests.jl index 7164667bac..b63541a75a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2116,6 +2116,53 @@ end @test dmt2.y ≈ 2.4 end + +struct GFUniform{T} + a::T + b::T +end +GFlogpdf(d::GFUniform, ::Real) = -log(d.b - d.a) + +struct GFNormal{T} + μ::T + σ::T +end +GFlogpdf(d::GFNormal, x::Real) = -(x - d.μ)^2 / (2 * d.σ^2) + +struct GFProductDist{V} + dists::V +end +function GFlogpdf(d::GFProductDist, x::Vector) + 
dists = d.dists + s = zero(eltype(x)) + for i in eachindex(x) + s += GFlogpdf(dists[i], x[i]) + end + return s +end + +struct GFNamedDist{Names, D<:NamedTuple{Names}} + dists::D +end + +function GFlogpdf(d::GFNamedDist{N}, x::NamedTuple{N}) where {N} + vt = values(x) + dists = d.dists + return mapreduce((dist, acc) -> GFlogpdf(dist, acc), +, dists, vt) +end + + +@testset "Getfield with reference" begin + Enzyme.API.runtimeActivity!(true) + + d = GFNamedDist((;a = GFNormal(0.0, 1.0), b = GFProductDist([GFUniform(0.0, 1.0), GFUniform(0.0, 1.0)]))) + p = (a = 1.0, b = [0.5, 0.5]) + dp = Enzyme.make_zero(p) + GFlogpdf(d, p) + autodiff(Reverse, GFlogpdf, Active, Const(d), Duplicated(p, dp)) + Enzyme.API.runtimeActivity!(false) +end + @testset "apply iterate" begin function mktup(v) tup = tuple(v...) From 183de43eaae6bf955c752a009fe25591d2d649fa Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 11 May 2024 15:08:38 -0700 Subject: [PATCH 040/495] Optimize away unnecessary recursive forward passes (#1431) --- Project.toml | 2 +- src/compiler.jl | 10 ++- src/compiler/optimize.jl | 140 ++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 8 +++ 4 files changed, 155 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 381de17f1b..e6f656068e 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.107" +Enzyme_jll = "0.0.108" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/src/compiler.jl b/src/compiler.jl index 74d6013fef..9c780fa598 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3017,7 +3017,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed push!(args_activity, API.DFT_DUP_NONEED) else - error("illegal annotation type") + error("illegal annotation type $T") end 
typeTree = typetree(source_typ, ctx, dl, seen) if isboxed @@ -5769,8 +5769,12 @@ function _thunk(job, postopt::Bool=true) end # Run post optimization pipeline - if postopt && job.config.params.ABI <: FFIABI - post_optimze!(mod, JIT.get_tm()) + if postopt + if job.config.params.ABI <: FFIABI + post_optimze!(mod, JIT.get_tm()) + else + propagate_returned!(mod) + end end return (mod, adjoint_name, primal_name, meta.TapeType) end diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index f5ae2f823f..1d88b56219 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -798,6 +798,127 @@ function prop_global!(g) return changed, newfns end +# From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 +function mayWriteToMemory(inst::LLVM.Instruction)::Bool + # we will ignore fense here + if isa(inst, LLVM.StoreInst) + return true + end + if isa(inst, LLVM.VAArgInst) + return true + end + if isa(inst, LLVM.AtomicCmpXchgInst) + return true + end + if isa(inst, LLVM.AtomicRMWInst) + return true + end + if isa(inst, LLVM.CatchPadInst) + return true + end + if isa(inst, LLVM.CatchRetInst) + return true + end + if isa(inst, LLVM.CallInst) || isa(inst, LLVM.InvokeInst) || isa(inst, LLVM.CallBrInst) + idx = reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); + + Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j in 1:count + attr = LLVM.Attribute(unsafe_load(Attrs, j)) + if kind(attr) == kind(EnumAttribute("readnone")) + return false + end + if kind(attr) == kind(EnumAttribute("readonly")) + return false + end + end + Libc.free(Attrs) + return true + end + # Ignoring load unordered case + return false +end + +function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) + calls = LLVM.CallInst[] + + for u in 
LLVM.uses(fn) + un = LLVM.user(u) + + # Only permit call users + if !isa(un, LLVM.CallInst) + return false + end + un = un::LLVM.CallInst + + # Passing the fn as an argument is not permitted + for op in collect(operands(un))[1:end-1] + if op == fn + return false + end + end + + # Something with a user is not permitted + for u2 in LLVM.uses(un) + return false + end + push!(calls, un) + end + if length(calls) == 0 + return false + end + + done = Set{LLVM.Function}() + todo = LLVM.Function[fn] + + while length(todo) != 0 + cur = pop!(todo) + if cur in done + continue + end + push!(done, cur) + + attrs = collect(function_attributes(cur)) + if any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) || any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + continue + end + + if LLVM.name(cur) == "julia.safepoint" + continue + end + + if isempty(blocks(cur)) + return false + end + for bb in blocks(cur) + for inst in instructions(bb) + if !mayWriteToMemory(inst) + continue + end + if isa(inst, LLVM.CallInst) + + fn2 = LLVM.called_operand(inst) + if isa(fn2, LLVM.Function) + push!(todo, fn2) + continue + end + end + return false + end + end + end + + for c in calls + parentf = LLVM.parent(LLVM.parent(c)) + push!(next, LLVM.name(parentf)) + LLVM.API.LLVMInstructionEraseFromParent(c) + end + push!(next, LLVM.name(fn)) + return true +end + function propagate_returned!(mod::LLVM.Module) globs = LLVM.GlobalVariable[] for g in globals(mod) @@ -824,6 +945,9 @@ function propagate_returned!(mod::LLVM.Module) if isempty(blocks(fn)) continue end + if remove_readonly_unused_calls!(fn, next) + changed = true + end attrs = collect(function_attributes(fn)) prevent = any(kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs) # if any(kind(attr) == kind(EnumAttribute("noinline")) for attr in attrs) @@ -1107,7 +1231,21 @@ function propagate_returned!(mod::LLVM.Module) if !changed break else - todo = collect(functions(mod)[name] for name 
in next) + todo = LLVM.Function[] + for name in next + fn = functions(mod)[name] + if linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage + has_user = false + for u in LLVM.uses(fn) + has_user = true + break + end + if !has_user + LLVM.API.LLVMDeleteFunction(fn) + end + end + push!(todo, fn) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index b63541a75a..a2fc9c6eda 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -196,6 +196,14 @@ end end end +@testset "Recursion optimization" begin + # Test that we can successfully optimize out the augmented primal from the recursive divide and conquer + fn = sprint() do io + Enzyme.Compiler.enzyme_code_llvm(io, sum, Active, Tuple{Duplicated{Vector{Float64}}}) + end + @test occursin("diffe",fn) + @test !occursin("aug",fn) +end # @testset "Split Tape" begin # f(x) = x[1] * x[1] From 19dbbb2f757c1408e73870f6de8ca88021eed76c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 11 May 2024 15:09:29 -0700 Subject: [PATCH 041/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e6f656068e..0188836aac 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.4" +version = "0.12.5" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 75e5311ac0dda08498fcfb0aa54e89930cb3e1c9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 11 May 2024 19:18:40 -0700 Subject: [PATCH 042/495] Docs: describe differentiable types (#1433) --- docs/src/faq.md | 109 ++++++++++++++++++++++++++++++++++++++++++++++++ src/compiler.jl | 3 ++ 2 files changed, 112 insertions(+) diff --git a/docs/src/faq.md b/docs/src/faq.md index 72fa1f97d9..5e57a8ada8 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -540,3 +540,112 @@ For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] 
+ i [u_y + i v_y] \right) = \fr ``` Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space). + +## What types are differentiable? + +Enzyme tracks differentiable dataflow through values. Specifically Enzyme tracks differentiable data in base types like Float32, Float64, Float16, BFloat16, etc. + +As a simple example: + +```jldoctest types +f(x) = x * x +Enzyme.autodiff(Forward, f, Duplicated(3.0, 1.0)) + +# output + +(6.0,) +``` + +Enzyme also tracks differentiable data in any types containing these base types (e.g. floats). For example, consider a struct or array containing floats. + +```jldoctest types +struct Pair + lhs::Float64 + rhs::Float64 +end +f_pair(x) = x.lhs * x.rhs +Enzyme.autodiff(Forward, f_pair, Duplicated(Pair(3.0, 2.0), Pair(1.0, 0.0))) + +# output + +(2.0,) +``` + +```jldoctest types +Enzyme.autodiff(Forward, sum, Duplicated([1.0, 2.0, 3.0], [5.0, 0.0, 100.0])) + + +# output + +(105.0,) +``` + +A differentiable data structure can be arbitrarily complex, such as a linked list. + + +```jldoctest types + +struct LList + prev::Union{Nothing, LList} + value::Float64 +end + +function make_list(x::Vector) + result = nothing + for value in reverse(x) + result = LList(result, value) + end + return result +end + +function list_sum(list::Union{Nothing, LList}) + result = 0.0 + while list != nothing + result += list.value + list = list.prev + end + return result +end + +list = make_list([1.0, 2.0, 3.0]) +dlist = make_list([5.0, 0.0, 100.0]) + +Enzyme.autodiff(Forward, list_sum, Duplicated(list, dlist)) + +# output + +(105.0,) +``` + +Presently Enzyme only considers floats as base types. As a result, Enzyme does not support differentiating data contained in Ints, Strings, or Vals. If it is desirable for Enzyme to add a base type, please open an issue. 
+ +```jldoctest types +f_int(x) = x * x +Enzyme.autodiff(Forward, f_int, DuplicatedNoNeed, Duplicated(3, 1)) + +# output + +ERROR: Return type `Int64` not marked Const, but type is guaranteed to be constant +``` + +```jldoctest types +f_str(x) = parse(Float64, x) * parse(Float64, x) + +autodiff(Forward, f_str, Duplicated("1.0", "1.0")) + +# output + +(0.0,) +``` + +```jldoctest types +f_val(::Val{x}) where x = x * x + +autodiff(Forward, f_val, Duplicated(Val(1.0), Val(1.0))) + +# output + +ERROR: Type of ghost or constant type Duplicated{Val{1.0}} is marked as differentiable. +``` + + diff --git a/src/compiler.jl b/src/compiler.jl index 9c780fa598..25cfd4d4cd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_try_substrtod", "jl_try_substrtod", "jl_f__apply_iterate", "ijl_field_index", "jl_field_index", "julia.call", "julia.call2", @@ -178,6 +179,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_try_substrtod", "jl_try_substrtod", "ijl_tagged_gensym", "jl_tagged_gensym", "jl_get_world_counter", "ijl_get_world_counter", "memhash32_seed", "memhash_seed", @@ -2790,6 +2792,7 @@ function annotate!(mod, mode) "ijl_reshape_array", "jl_reshape_array", "ijl_eqtable_get", "jl_eqtable_get", "jl_gc_run_pending_finalizers", + "ijl_try_substrtod", "jl_try_substrtod", ) if haskey(fns, fname) fn = fns[fname] From d7931135656d2120498316ff1531d7fe50a81aec Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 May 2024 11:46:17 -0700 Subject: [PATCH 043/495] Remove unnecessary val in thunk types (#1432) --- src/Enzyme.jl | 12 ++++++------ src/compiler.jl | 22 +++++++++++----------- src/rules/parallelrules.jl | 8 ++++---- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 981bbadbab..d1f60e0936 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -431,7 +431,7 @@ 
code, as well as high-order differentiation. adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) - thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr) + thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) if rt <: Active args = (args..., Compiler.default_adjoint(eltype(rt))) elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed @@ -490,11 +490,11 @@ code, as well as high-order differentiation. throw(ErrorException("Active Returns not allowed in forward mode")) end - ReturnPrimal = Val(RT <: Duplicated || RT <: BatchDuplicated) + ReturnPrimal = RT <: Duplicated || RT <: BatchDuplicated ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) - thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal)) + thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) thunk(f, args...) 
end @@ -849,8 +849,8 @@ result, ∂v, ∂A RT end - aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, rt, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr) - adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, rt, TT, Val{width}, TapeType}(adjoint_ptr) + aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, rt, TT, width, ReturnPrimal, TapeType}(primal_ptr) + adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, rt, TT, width, TapeType}(adjoint_ptr) aug_thunk, adj_thunk end diff --git a/src/compiler.jl b/src/compiler.jl index 25cfd4d4cd..fd6d6b04d6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5261,20 +5261,20 @@ struct CompileResult{AT, PT} end @inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk, Width, ReturnPrimal, TT, RT, fn, Cvoid, args...) +enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @inline (thunk::ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, ForwardModeThunk, Width, ReturnPrimal, TT, RT, fn, Cvoid, args...) +enzyme_call(Val(false), thunk.adjoint, ForwardModeThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @inline (thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT})(fn::FA, args...) where {PT, FA, Width, RT, TT, TapeT} = -enzyme_call(Val(false), thunk.adjoint, AdjointThunk, Width, #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) +enzyme_call(Val(false), thunk.adjoint, AdjointThunk, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) @inline raw_enzyme_call(thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT}, fn::FA, args...) 
where {PT, FA, Width, RT, TT, TapeT} = -enzyme_call(Val(true), thunk.adjoint, AdjointThunk, Width, #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) +enzyme_call(Val(true), thunk.adjoint, AdjointThunk, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) @inline (thunk::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeT})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal, TapeT} = -enzyme_call(Val(false), thunk.primal, AugmentedForwardThunk, Width, ReturnPrimal, TT, RT, fn, TapeT, args...) +enzyme_call(Val(false), thunk.primal, AugmentedForwardThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, TapeT, args...) @inline raw_enzyme_call(thunk::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeT}, fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal, TapeT} = -enzyme_call(Val(true), thunk.primal, AugmentedForwardThunk, Width, ReturnPrimal, TT, RT, fn, TapeT, args...) +enzyme_call(Val(true), thunk.primal, AugmentedForwardThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, TapeT, args...) 
function jl_set_typeof(v::Ptr{Cvoid}, T) @@ -5381,7 +5381,7 @@ function add_one_in_place(x) return nothing end -@generated function enzyme_call(::Val{RawCall}, fptr::PT, ::Type{CC}, ::Type{Val{width}}, ::Val{returnPrimal}, tt::Type{T}, +@generated function enzyme_call(::Val{RawCall}, fptr::PT, ::Type{CC}, ::Val{width}, ::Val{returnPrimal}, tt::Type{T}, rt::Type{RT}, fn::FA, ::Type{TapeType}, args::Vararg{Any, N}) where {RawCall, PT, FA, T, RT, TapeType, N, CC, width, returnPrimal} JuliaContext() do ctx @@ -5862,8 +5862,8 @@ end compile_result = cached_compilation(job) if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType - AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal), TapeType} - AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, TapeType} + AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} + AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} return quote Base.@_inline_meta augmented = $AugT($(compile_result.primal)) @@ -5871,13 +5871,13 @@ end (augmented, adjoint) end elseif Mode == API.DEM_ReverseModeCombined - CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} + CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} return quote Base.@_inline_meta $CAdjT($(compile_result.adjoint)) end elseif Mode == API.DEM_ForwardMode - FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} + FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, 
Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} return quote Base.@_inline_meta $FMT($(compile_result.adjoint)) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index a2fefee446..c13e21c1ce 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -203,7 +203,7 @@ end push!(function_attributes(functions(mod)[fwdmodenm]), EnumAttribute("alwaysinline")) permit_inlining!(functions(mod)[fwdmodenm]) end - thunkTy = ForwardModeThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, Val{width}, #=returnPrimal=#Val(false)} + thunkTy = ForwardModeThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, #=returnPrimal=#false} subfunc = functions(mod)[fwdmodenm] elseif mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient @@ -243,10 +243,10 @@ end end if mode == API.DEM_ReverseModePrimal - thunkTy = AugmentedForwardThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, Val{width}, #=returnPrimal=#Val(true), TapeType} + thunkTy = AugmentedForwardThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, #=returnPrimal=#true, TapeType} subfunc = functions(mod)[augfwdnm] else - thunkTy = AdjointThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, Val{width}, TapeType} + thunkTy = AdjointThunk{Ptr{Cvoid}, dupClosure ? 
Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, TapeType} subfunc = functions(mod)[adjointnm] end else @@ -736,4 +736,4 @@ function wait_rev(B, orig, gutils, tape) debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) return nothing -end \ No newline at end of file +end From a58e73a6fbba4a2295f711df50c2a4edfd267724 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 May 2024 18:49:45 -0700 Subject: [PATCH 044/495] Document API options (#1435) --- src/api.jl | 152 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/src/api.jl b/src/api.jl index 8a06999d01..017dd86b91 100644 --- a/src/api.jl +++ b/src/api.jl @@ -333,61 +333,192 @@ function zcache!(val) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end + +""" + printperf!(val::Bool) + +An debugging option for developers of Enzyme. If one sets this flag prior +to the first differentiation of a function, Enzyme will print (to stderr) +performance information about generated derivative programs. It will provide +debug information that warns why particular values are cached for the +reverse pass, and thus require additional computation/storage. This is particularly +helpful for debugging derivatives which OOM or otherwise run slow. +ff by default +""" function printperf!(val) ptr = cglobal((:EnzymePrintPerf, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + printdiffuse!(val::Bool) + +An debugging option for developers of Enzyme. If one sets this flag prior +to the first differentiation of a function, Enzyme will print (to stderr) +information about each LLVM value -- specifically whether it and its shadow +is required for computing the derivative. 
In contrast to `printunnecessary!`, +this flag prints a debug log for the analysis which determines for each value +and shadow value, whether it can find a user which would require it to be kept +around (rather than being deleted). This is prior to any cache optimizations +and a debug log of Differential Use Analysis. This may be helpful for debugging +caching, phi node deletion, performance, and other errors. +Off by default +""" function printdiffuse!(val) ptr = cglobal((:EnzymePrintDiffUse, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + printtype!(val::Bool) + +A debugging option for developers of Enzyme. If one sets this flag prior +to the first differentiation of a function, Enzyme will print (to stderr) +a log of all decisions made during Type Analysis (the analysis by which +Enzyme determines the type of all values in the program). This may be useful +for debugging correctness errors, illegal type analysis errors, insufficient +type information errors, correctness, and performance errors. +Off by default +""" function printtype!(val) ptr = cglobal((:EnzymePrintType, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + printactivity!(val::Bool) + +A debugging option for developers of Enzyme. If one sets this flag prior +to the first differentiation of a function, Enzyme will print (to stderr) +a log of all decisions made during Activity Analysis (the analysis which +determines what values/instructions are differentiated). This may be useful +for debugging MixedActivity errors, correctness, and performance errors. +Off by default +""" function printactivity!(val) ptr = cglobal((:EnzymePrintActivity, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + printall!(val::Bool) + +A debugging option for developers of Enzyme. 
If one sets this flag prior +to the first differentiation of a function, Enzyme will print (to stderr) +the LLVM function being differentiated, as well as all generated derivatives +immediately after running Enzyme (but prior to any other optimizations). +Off by default +""" function printall!(val) ptr = cglobal((:EnzymePrint, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + printunnecessary!(val::Bool) + +A debugging option for developers of Enzyme. If one sets this flag prior +to the first differentiation of a function, Enzyme will print (to stderr) +information about each LLVM value -- specifically whether it and its shadow +is required for computing the derivative. In contrast to `printdiffuse!`, +this flag prints the final results after running cache optimizations such +as minCut (see Recompute vs Cache Heuristics from https://c.wsmoses.com/papers/EnzymeGPU.pdf +and slides 31-33 from https://c.wsmoses.com/presentations/enzyme-sc.pdf for a +description of the caching algorithm). This may be helpful for debugging +caching, phi node deletion, performance, and other errors. +Off by default +""" function printunnecessary!(val) ptr = cglobal((:EnzymePrintUnnecessary, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + inlineall!(val::Bool) + +Whether to inline all (non-recursive) functions generated by Julia within a +single compilation unit. This may improve Enzyme's ability to successfully +differentiate code and improve performance of the original and generated +derivative program. It often, however, comes with an increase in compile time. +This is off by default. +""" function inlineall!(val) ptr = cglobal((:EnzymeInline, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end + +""" + maxtypeoffset!(val::Int) + +Enzyme runs a type analysis to deduce the corresponding types of all values being +differentiated. 
This is necessary to compute correct derivatives of various values. +To ensure this analysis terminates, it operates on a finite lattice of possible +states. This function sets the maximum offset into a type that Enzyme will consider. +A smaller value will cause type analysis to run faster, but may result in some +necessary types not being found and result in unknown type errors. A larger value +may result in unknown type errors being resolved by searching a larger space, but +may run longer. The default setting is 512. +""" function maxtypeoffset!(val) ptr = cglobal((:MaxTypeOffset, libEnzyme)) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, Int64), ptr, val) end +""" + looseTypeAnalysis!(val::Bool) + +Enzyme runs a type analysis to deduce the corresponding types of all values being +differentiated. This is necessary to compute correct derivatives of various values. +For example, a copy of Float32's requires a different derivative than a memcpy of +Float64's, Ptr's, etc. In some cases Enzyme may not be able to deduce all the types +necessary and throw an unknown type error. If this is the case, open an issue. +One can silence these issues by setting `looseTypeAnalysis!(true)` which tells +Enzyme to make its best guess. This will remove the error and allow differentiation +to continue, however, it may produce incorrect results. Alternatively one can +consider increasing the space of the evaluated type lattice which gives Enzyme +more time to run a more thorough analysis through the use of `maxtypeoffset!(val)`. +""" function looseTypeAnalysis!(val) ptr = cglobal((:looseTypeAnalysis, libEnzyme)) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end + +""" + strictAliasing!(val::Bool) + +Whether Enzyme's type analysis will assume strict aliasing semantics. When strict +aliasing semantics are on (the default), Enzyme can propagate type information up +through conditional branches. 
This may lead to illegal type errors when analyzing +code with unions. Disabling strict aliasing will enable these union types to be +correctly analyzed. However, it may lead to some errors that sufficient type information +cannot be deduced. One can turn these insufficient type information errors into +warnings by calling `looseTypeAnalysis!(true)` which tells Enzyme to use its best +guess in such scenarios. +""" function strictAliasing!(val) ptr = cglobal((:EnzymeStrictAliasing, libEnzyme)) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + fast_math!(val::Bool) + +Whether generated derivatives have fast math on or off, default on. +""" function fast_math!(val) ptr = cglobal((:EnzymeFastMath, libEnzyme)) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + strong_zero!(val::Bool) + +Whether to enforce that multiplication by zero yields a zero result, even when multiplying +against a NaN or infinity. Necessary for some programs in which a value has a zero +derivative since it is unused, even if it has an otherwise infinite or NaN derivative. +""" function strong_zero!(val) ptr = cglobal((:EnzymeStrongZero, libEnzyme)) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) @@ -438,16 +569,37 @@ function runtimeActivity() return EnzymeGetCLBool(ptr) != 0 end +""" + typeWarning!(val::Bool) + +Whether to print a warning when Type Analysis learns information about a value's type +which cannot be represented in the current size of the lattice. See `maxtypeoffset!` for +more information. +Off by default. +""" function typeWarning!(val) ptr = cglobal((:EnzymeTypeWarning, libEnzyme)) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + instname!(val::Bool) + +Whether to add a name to all LLVM values. This may be helpful for debugging generated +programs, both primal and derivative. +Off by default. 
+""" function instname!(val) ptr = cglobal((:EnzymeNameInstructions, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end +""" + memmove_warning!(val::Bool) + +Whether to issue a warning when differentiating memmove. +Off by default. +""" function memmove_warning!(val) ptr = cglobal((:EnzymeMemmoveWarning, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) From e58b8e6d99db400c653cd69e8e2e242f3ed3367e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 May 2024 18:50:24 -0700 Subject: [PATCH 045/495] Modernize developer instructions to use Preferences.tom (#1436) --- docs/src/dev_docs.md | 62 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/docs/src/dev_docs.md b/docs/src/dev_docs.md index 06f235f9a6..6a3bd54742 100644 --- a/docs/src/dev_docs.md +++ b/docs/src/dev_docs.md @@ -1,15 +1,71 @@ # Enzyme developer documentation -## Development of Enzyme and Enzyme.jl together +## Development of Enzyme and Enzyme.jl together (recommended) -Normally Enzyme.jl downloads and install Enzyme for the user automatically since Enzyme needs to be built against +Normally Enzyme.jl downloads and installs [Enzyme](https://github.com/EnzymeAD/Enzyme) for the user automatically since Enzyme needs to be built against Julia bundeled LLVM. In case that you are making updates to Enzyme and want to test them against Enzyme.jl the instructions below should help you get started. +Start Julia in your development copy of Enzyme.jl and initialize the deps project + +```bash +~/s/Enzyme.jl (master)> julia --project=deps +``` + +```julia-repl +julia> # Hit the `]` key to enter package repl. +(deps) pkg> instantiate +``` + +We can now build a custom version of Enzyme for use in Enzyme.jl. To build the latest commit on the main branch of Enzyme, run the following. +It may take a few minutes to compile fully. 
+ +```bash +~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl +``` + +You will now find a file LocalPrefernces.toml which has been generated and contains a path to the new Enzyme_jll binary you have built. +To use your Enzyme_jll instead of the default shipped by Enzyme.jl, ensure that this file is at the root of any Julia project you wish +to test it with *and* that the Julia project has Enzyme_jll as an explicit dependency. Note that an indirect dependency here is not +sufficient (e.g. just because a project depends on Enzyme.jl, which depends on Enzyme_jll, does not mean that your project will pick up +this file unless you also add a direct dependency to Enzyme_jll). + +To test whether your project found the custom version of Enzyme_jll, you can inspect the path of the Enzyme_jll library in use as follows. + +```bash +~/my/project.jl (master)> julia --project=. +``` + +```julia-repl +julia> using Enzyme_jll +julia> Enzyme_jll.libEnzyme_path +"${JULIA_PKG_DEVDIR}/Enzyme_jll/override/lib/LLVMEnzyme-9.so" +``` + +This should correspond to the path in the LocalPreferences.toml you just generated. + +Note that your system can have only one custom built Enzyme_jll at a time. If you build one version for one version of Enzyme or Julia +and later build a new verison of Enzyme, it remove the old build. + +Note that Julia versions are tightly coupled and you cannot use an Enzyme_jll built for one version of Julia for another version of Julia. + +The same script can also be used to build Enzyme_jll for a branch other than main as follows. + +```bash +~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl --branch mybranch +``` + +It can also be used to build Enzyme_jll from a local copy of Enzyme on your machine, which do not need to be committed to git. 
+ +```bash +~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl ../path/to/Enzyme +``` + +## Development of Enzyme and Enzyme.jl together (manual) Start Julia in your development copy of Enzyme.jl ```bash -~/s/Enzyme (master)> julia --project=. +~/s/Enzyme.jl (master)> julia --project=. ``` Then create a development copy of Enzyme_jll and activate it within. From 8002ac5b6326e011dc2ba7ea09b82c3cc965eb90 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 May 2024 22:45:09 -0700 Subject: [PATCH 046/495] Add macro to import rrule from chainrules (#996) * Add import frule functionality * Add wip rrule importer * fix --- ext/EnzymeChainRulesCoreExt.jl | 171 ++++++++++++++++++++++++++++++++- src/Enzyme.jl | 6 ++ test/ext/chainrulescore.jl | 61 ++++++++++++ 3 files changed, 237 insertions(+), 1 deletion(-) diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl index 2c8d180a57..4549aa84b8 100644 --- a/ext/EnzymeChainRulesCoreExt.jl +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -104,4 +104,173 @@ function Enzyme._import_frule(fn, tys...) end -end # module \ No newline at end of file +""" + import_rrule(::fn, tys...) + +Automatically import a ChainRules.rrule as a custom reverse mode EnzymeRule. When called in batch mode, this +will end up calling the primal multiple times which results in slower code. This macro assumes that the underlying +function to be imported is read-only, and returns a Duplicated or Const object. This macro also assumes that the +inputs permit a .+= operation and that the output has a valid Enzyme.make_zero function defined. It also assumes +that overwritten(x) accurately describes if there is any non-preserved data from forward to reverse, not just +the outermost data structure being overwritten as provided by the specification. + +Finally, this macro falls back to almost always caching all of the inputs, even if it may not be needed for the +derivative computation. 
+ +As a result, this auto importer is also likely to be slower than writing your own rule, and may also be slower +than not having a rule at all. + +Use with caution. + +``` +Enzyme.@import_rrule(typeof(Base.sort), Any); +``` +""" +function Enzyme._import_rrule(fn, tys...) + vals = [] + valtys = [] + exprs = [] + primals = [] + tangents = [] + tangentsi = [] + anns = [] + nothings = [] + ntys = length(tys) + act_res = Expr[:(fn isa Active ? res[1] : nothing)] + invertcomb = Expr[] + # TODO at one point extend api to support active fn's + # push!(nothings, :nothing) + # push!(invertcomb, + # quote + # fn isa Active ? ( + # (EnzymeRules.width(config) == 1) ? tcomb[1][1] : + # ntuple(Val(EnzymeRules.width(config))) do batch_i + # Base.@_inline_meta + # tcomb[batch_i][1] + # end + # ) : nothing + # end) + + for (i, ty) in enumerate(tys) + push!(nothings, :(nothing)) + val = Symbol("arg_$i") + TA = Symbol("AN_$i") + e = :($val::$TA) + push!(anns, :($TA <: Annotation{<:$ty})) + push!(vals, val) + push!(exprs, e) + primal = Symbol("primcopy_$i") + push!(primals, primal) + push!(valtys, :($primal = $(EnzymeRules.overwritten)(config)[$i+1] ? deepcopy($val.val) : $val.val)) + push!(tangents, :($val isa $Enzyme.Const ? $ChainRulesCore.NoTangent() : $val.dval)) + push!(tangentsi, :($val isa $Enzyme.Const ? $ChainRulesCore.NoTangent() : $val.dval[i])) + push!(act_res, :($val isa Active ? (res[$i+1] isa $ChainRulesCore.NoTangent ? zero($val) : $ChainRulesCore.unthunk(res[$i+1]) ) : nothing)) + push!(invertcomb, quote + $val isa Active ? ( + (EnzymeRules.width(config) == 1) ? tcomb[1][$i+1] : + ntuple(Val(EnzymeRules.width(config))) do batch_i + Base.@_inline_meta + tcomb[batch_i][$i+1] + end + ) : nothing + end) + end + + + quote + function EnzymeRules.augmented_primal(config, fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} + $(valtys...) 
+ + res, pullback = if RetAnnotation <: Const + (fn.val($(primals...); kwargs...), nothing) + else + $ChainRulesCore.rrule(fn.val, $(primals...); kwargs...) + end + + primal = if EnzymeRules.needs_primal(config) + res + else + nothing + end + + shadow = if !EnzymeRules.needs_shadow(config) + nothing + else + if EnzymeRules.width(config) == 1 + Enzyme.make_zero(res) + else + ntuple(Val(EnzymeRules.width(config))) do j + Base.@_inline_meta + Enzyme.make_zero(res) + end + end + end + + return EnzymeRules.AugmentedReturn(primal, shadow, (shadow, pullback)) + end + + function EnzymeRules.reverse(config, fn::FA, ::Type{RetAnnotation}, tape::TapeTy, $(exprs...); kwargs...) where {RetAnnotation, TapeTy, FA<:Annotation{<:$(esc(fn))}, $(anns...)} + if !(RetAnnotation <: Const) + shadow, pullback = tape + + tcomb = ntuple(Val(EnzymeRules.width(config))) do batch_i + Base.@_inline_meta + shad = EnzymeRules.width(config) == 1 ? shadow : shadow[batch_i] + res = pullback(shad) + + for (cr, en) in zip(res, (fn, $(vals...),)) + if en isa Const || cr isa $ChainRulesCore.NoTangent + continue + end + if en isa Active + continue + end + if EnzymeRules.width(config) == 1 + en.dval .+= cr + else + en.dval[batch_i] .+= cr + end + end + + ($(act_res...),) + end + + return ($(invertcomb...),) + end + + return ($(nothings...),) + end + + function EnzymeRules.reverse(config, fn::FA, dval::Active{RetAnnotation}, tape::TapeTy, $(exprs...); kwargs...) where {RetAnnotation, TapeTy, FA<:Annotation{<:$(esc(fn))}, $(anns...)} + oldshadow, pullback = tape + + shadow = dval.val + + tcomb = ntuple(Val(EnzymeRules.width(config))) do batch_i + Base.@_inline_meta + shad = EnzymeRules.width(config) == 1 ? 
shadow : shadow[batch_i] + res = pullback(shad) + + for (cr, en) in zip(res, (fn, $(vals...),)) + if en isa Const || cr isa $ChainRulesCore.NoTangent + continue + end + if en isa Active + continue + end + if EnzymeRules.width(config) == 1 + en.dval .+= cr + else + en.dval[batch_i] .+= cr + end + end + + ($(act_res...),) + end + + return ($(invertcomb...),) + end + end +end + +end # module diff --git a/src/Enzyme.jl b/src/Enzyme.jl index d1f60e0936..9087055022 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1231,4 +1231,10 @@ macro import_frule(args...) return _import_frule(args...) end +function _import_rrule end # defined in EnzymeChainRulesCoreExt extension + +macro import_rrule(args...) + return _import_rrule(args...) +end + end # module diff --git a/test/ext/chainrulescore.jl b/test/ext/chainrulescore.jl index 217176a657..38fe07e87f 100644 --- a/test/ext/chainrulescore.jl +++ b/test/ext/chainrulescore.jl @@ -64,6 +64,67 @@ fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] end end +rdiff(f, x::Number) = autodiff(Reverse, f, Active, Active(x))[1][1] + +@testset "import_rrule" begin + f1(x) = 2*x + ChainRulesCore.@scalar_rule f1(x) (5*one(x),) + Enzyme.@import_rrule typeof(f1) Any + @test rdiff(f1, 1f0) === 5f0 + @test rdiff(f1, 1.0) === 5.0 + + # specific signature + f2(x) = 2*x + ChainRulesCore.@scalar_rule f2(x) (5*one(x),) + Enzyme.@import_rrule typeof(f2) Float32 + @test rdiff(f2, 1f0) === 5f0 + @test rdiff(f2, 1.0) === 2.0 + + # two arguments + f3(x, y) = 2*x + y + ChainRulesCore.@scalar_rule f3(x, y) (5*one(x), y) + Enzyme.@import_rrule typeof(f3) Any Any + @test rdiff(x -> f3(x, 1.0), 2.) === 5.0 + @test rdiff(y -> f3(1.0, y), 2.) 
=== 2.0 + + @testset "batch duplicated" begin + x = [1.0, 2.0, 0.0] + Enzyme.@import_rrule typeof(Base.sort) Any + + test_reverse(Base.sort, Duplicated, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_reverse(Base.sort, Duplicated, (x, DuplicatedNoNeed)) + test_reverse(Base.sort, DuplicatedNoNeed, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_reverse(Base.sort, DuplicatedNoNeed, (x, DuplicatedNoNeed)) + test_reverse(Base.sort, Const, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_reverse(Base.sort, Const, (x, DuplicatedNoNeed)) + + test_reverse(Base.sort, Const, (x, Const)) + + # ChainRules does not support this case (returning notangent) + # test_reverse(Base.sort, Duplicated, (x, Const)) + # test_reverse(Base.sort, DuplicatedNoNeed, (x, Const)) + + test_reverse(Base.sort, BatchDuplicated, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_reverse(Base.sort, BatchDuplicated, (x, BatchDuplicatedNoNeed)) + test_reverse(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_reverse(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicatedNoNeed)) + test_reverse(Base.sort, Const, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_reverse(Base.sort, Const, (x, BatchDuplicatedNoNeed)) + + # ChainRules does not support this case (returning notangent) + # test_reverse(Base.sort, BatchDuplicated, (x, Const)) + # test_reverse(Base.sort, BatchDuplicatedNoNeed, (x, Const)) + end +end + + + From d7ef5e75a04577ef2202237ea25974b2c93fc45b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 May 2024 22:52:59 -0700 Subject: [PATCH 047/495] Update dev_docs.md --- docs/src/dev_docs.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/src/dev_docs.md b/docs/src/dev_docs.md index 6a3bd54742..85b0ee6150 100644 --- a/docs/src/dev_docs.md +++ b/docs/src/dev_docs.md @@ -24,13 +24,13 @@ It may take a few minutes to compile 
fully. ~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl ``` -You will now find a file LocalPrefernces.toml which has been generated and contains a path to the new Enzyme_jll binary you have built. -To use your Enzyme_jll instead of the default shipped by Enzyme.jl, ensure that this file is at the root of any Julia project you wish -to test it with *and* that the Julia project has Enzyme_jll as an explicit dependency. Note that an indirect dependency here is not -sufficient (e.g. just because a project depends on Enzyme.jl, which depends on Enzyme_jll, does not mean that your project will pick up -this file unless you also add a direct dependency to Enzyme_jll). +You will now find a file LocalPrefernces.toml which has been generated and contains a path to the new Enzyme\_jll binary you have built. +To use your Enzyme\_jll instead of the default shipped by Enzyme.jl, ensure that this file is at the root of any Julia project you wish +to test it with *and* that the Julia project has Enzyme\_jll as an explicit dependency. Note that an indirect dependency here is not +sufficient (e.g. just because a project depends on Enzyme.jl, which depends on Enzyme\_jll, does not mean that your project will pick up +this file unless you also add a direct dependency to Enzyme\_jll). -To test whether your project found the custom version of Enzyme_jll, you can inspect the path of the Enzyme_jll library in use as follows. +To test whether your project found the custom version of Enzyme\_jll, you can inspect the path of the Enzyme\_jll library in use as follows. ```bash ~/my/project.jl (master)> julia --project=. @@ -44,18 +44,18 @@ julia> Enzyme_jll.libEnzyme_path This should correspond to the path in the LocalPreferences.toml you just generated. -Note that your system can have only one custom built Enzyme_jll at a time. If you build one version for one version of Enzyme or Julia +Note that your system can have only one custom built Enzyme\_jll at a time. 
If you build one version for one version of Enzyme or Julia and later build a new verison of Enzyme, it remove the old build. -Note that Julia versions are tightly coupled and you cannot use an Enzyme_jll built for one version of Julia for another version of Julia. +Note that Julia versions are tightly coupled and you cannot use an Enzyme\_jll built for one version of Julia for another version of Julia. -The same script can also be used to build Enzyme_jll for a branch other than main as follows. +The same script can also be used to build Enzyme\_jll for a branch other than main as follows. ```bash ~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl --branch mybranch ``` -It can also be used to build Enzyme_jll from a local copy of Enzyme on your machine, which do not need to be committed to git. +It can also be used to build Enzyme\_jll from a local copy of Enzyme on your machine, which do not need to be committed to git. ```bash ~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl ../path/to/Enzyme @@ -68,7 +68,7 @@ Start Julia in your development copy of Enzyme.jl ~/s/Enzyme.jl (master)> julia --project=. ``` -Then create a development copy of Enzyme_jll and activate it within. +Then create a development copy of Enzyme\_jll and activate it within. ```julia-repl julia> using Enzyme_jll From f8d6b47d15cf51eaec063bdea1dec9307cf0a146 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 May 2024 22:55:53 -0700 Subject: [PATCH 048/495] Update api.jl --- src/api.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/api.jl b/src/api.jl index 017dd86b91..84080d4b49 100644 --- a/src/api.jl +++ b/src/api.jl @@ -356,7 +356,7 @@ end An debugging option for developers of Enzyme. If one sets this flag prior to the first differentiation of a function, Enzyme will print (to stderr) information about each LLVM value -- specifically whether it and its shadow -is required for computing the derivative. 
In contrast to `printunnecessary!`, +is required for computing the derivative. In contrast to [`printunnecessary!`](@ref), this flag prints debug log for the analysis which determines for each value and shadow value, whether it can find a user which would require it to be kept around (rather than being deleted). This is prior to any cache optimizations @@ -420,10 +420,10 @@ end An debugging option for developers of Enzyme. If one sets this flag prior to the first differentiation of a function, Enzyme will print (to stderr) information about each LLVM value -- specifically whether it and its shadow -is required for computing the derivative. In contrast to `printdiffuse!`, +is required for computing the derivative. In contrast to [`printdiffuse!`](@ref), this flag prints the final results after running cache optimizations such -as minCut (see Recompute vs Cache Heuristics from https://c.wsmoses.com/papers/EnzymeGPU.pdf -and slides 31-33 from https://c.wsmoses.com/presentations/enzyme-sc.pdf) for a +as minCut (see Recompute vs Cache Heuristics from [this paper](https://c.wsmoses.com/papers/EnzymeGPU.pdf) +and slides 31-33 from [this presentation](https://c.wsmoses.com/presentations/enzyme-sc.pdf)) for a description of the caching algorithm. This may be helpful for debugging caching, phi node deletion, performance, and other errors. Off by default @@ -477,7 +477,7 @@ One can silence these issues by setting `looseTypeAnalysis!(true)` which tells Enzyme to make its best guess. This will remove the error and allow differentiation to continue, however, it may produce incorrect results. 
Alternatively one can consider increasing the space of the evaluated type lattice which gives Enzyme -more time to run a more thorough analysis through the use of `maxtypeoffset!(val)` +more time to run a more thorough analysis through the use of [`maxtypeoffset!`](@ref) """ function looseTypeAnalysis!(val) ptr = cglobal((:looseTypeAnalysis, libEnzyme)) @@ -494,7 +494,7 @@ through conditional branches. This may lead to illegal type errors when analyzin code with unions. Disabling strict aliasing will enable these union types to be correctly analyzed. However, it may lead to some errors that sufficient type information cannot be deduced. One can turn these insufficient type information errors into to -warnings by calling `looseTypeAnalysis!(true)` which tells Enzyme to use its best +warnings by calling [`looseTypeAnalysis!`](@ref)`(true)` which tells Enzyme to use its best guess in such scenarios. """ function strictAliasing!(val) @@ -573,7 +573,7 @@ end typeWarning!(val::Bool) Whether to print a warning when Type Analysis learns informatoin about a value's type -which cannot be represented in the current size of the lattice. See `maxtypeoffset` for +which cannot be represented in the current size of the lattice. See [`maxtypeoffset!`](@ref) for more information. Off by default. 
""" From c65a2f2c54c35f1550a5be6971e4ea83267d31b6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 May 2024 23:36:04 -0700 Subject: [PATCH 049/495] Fix static arrays on forward mode gradient call (#1438) --- Project.toml | 19 +++++++++++-------- ext/EnzymeStaticArraysExt.jl | 24 ++++++++++++++++++++++++ test/runtests.jl | 8 ++++++++ 3 files changed, 43 insertions(+), 8 deletions(-) create mode 100644 ext/EnzymeStaticArraysExt.jl diff --git a/Project.toml b/Project.toml index 0188836aac..f2519af7c9 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,16 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" @@ -28,14 +38,7 @@ Preferences = "1.4" SpecialFunctions = "1, 2" julia = "1.6" -[extensions] -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeSpecialFunctionsExt = "SpecialFunctions" - [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl new file mode 100644 index 0000000000..672d1c03bc --- /dev/null +++ b/ext/EnzymeStaticArraysExt.jl @@ -0,0 +1,24 @@ +module EnzymeStaticArraysExt + +using StaticArrays +using Enzyme + +@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} + ntuple(Val(L)) do i + Base.@_inline_meta + 
StaticArrays.SArray{S, T, N, L}(Enzyme.onehot(NTuple{L, T})[i]) + end +end + +@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}, start, endl) where {S, T, N, L} + ntuple(Val(endl-start+1)) do i + Base.@_inline_meta + StaticArrays.SArray{S, T, N, L}( + ntuple(Val(N)) do idx + Base.@_inline_meta + return (i + start - 1 == idx) ? 1.0 : 0.0 + end) + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index a2fc9c6eda..40e4280dc6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2496,6 +2496,14 @@ end dx = Enzyme.gradient(Reverse, prod, x) @test dx isa SArray @test dx ≈ [0 30 0] + +@static if VERSION ≥ v"1.9-" + x = @SArray [5.0 0.0 6.0] + dx = Enzyme.gradient(Forward, prod, x) + @test dx[1] ≈ 0 + @test dx[2] ≈ 30 + @test dx[3] ≈ 0 +end end From f4acb2b5e64d97de5fdacdd5e2d8ddd46a89444d Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 13 May 2024 00:29:42 -0700 Subject: [PATCH 050/495] Improve recursion performance (#1439) --- src/compiler/optimize.jl | 23 +++++++++++++++++++---- test/runtests.jl | 14 ++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 1d88b56219..7583d957c0 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -844,6 +844,7 @@ end function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) calls = LLVM.CallInst[] + hasUser = false for u in LLVM.uses(fn) un = LLVM.user(u) @@ -862,13 +863,11 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) # Something with a user is not permitted for u2 in LLVM.uses(un) - return false + hasUser = true + break end push!(calls, un) end - if length(calls) == 0 - return false - end done = Set{LLVM.Function}() todo = LLVM.Function[fn] @@ -909,6 +908,22 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end end end + + changed = false + attrs = collect(function_attributes(fn)) + if !any(kind(attr) == 
kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs) + delete!(function_attributes(fn), EnumAttribute("writeonly")) + push!(function_attributes(fn), EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("readonly")) + end + changed = true + end + + if length(calls) == 0 || hasUser + return changed + end for c in calls parentf = LLVM.parent(LLVM.parent(c)) diff --git a/test/runtests.jl b/test/runtests.jl index 40e4280dc6..c512786a98 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -196,6 +196,8 @@ end end end +sumsq2(x) = sum(abs2, x) +sumsin(x) = sum(sin, x) @testset "Recursion optimization" begin # Test that we can successfully optimize out the augmented primal from the recursive divide and conquer fn = sprint() do io @@ -203,6 +205,18 @@ end end @test occursin("diffe",fn) @test !occursin("aug",fn) + + fn = sprint() do io + Enzyme.Compiler.enzyme_code_llvm(io, sumsq2, Active, Tuple{Duplicated{Vector{Float64}}}) + end + @test occursin("diffe",fn) + @test !occursin("aug",fn) + + fn = sprint() do io + Enzyme.Compiler.enzyme_code_llvm(io, sumsin, Active, Tuple{Duplicated{Vector{Float64}}}) + end + @test occursin("diffe",fn) + @test !occursin("aug",fn) end # @testset "Split Tape" begin From 8b2798131a658b45cb9287a051f59afc3874e912 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 13 May 2024 11:14:21 -0700 Subject: [PATCH 051/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f2519af7c9..d85a82e3d3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.5" +version = "0.12.6" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -30,7 +30,7 @@ 
EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.108" +Enzyme_jll = "0.0.109" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From fbcd7d3b28f2798cd1d2a617c2024628c7d599cb Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 13 May 2024 14:55:52 -0700 Subject: [PATCH 052/495] Nicer method errors (#1444) --- src/compiler.jl | 24 ++++++++++++++++++------ test/runtests.jl | 10 ++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fd6d6b04d6..2f17e8a4fb 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5261,13 +5261,13 @@ struct CompileResult{AT, PT} end @inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) +enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @inline (thunk::ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, ForwardModeThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) +enzyme_call(Val(false), thunk.adjoint, ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @inline (thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT})(fn::FA, args...) where {PT, FA, Width, RT, TT, TapeT} = -enzyme_call(Val(false), thunk.adjoint, AdjointThunk, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) +enzyme_call(Val(false), thunk.adjoint, AdjointThunk{PT, FA, RT, TT, Width, TapeT}, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) 
@inline raw_enzyme_call(thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT}, fn::FA, args...) where {PT, FA, Width, RT, TT, TapeT} = enzyme_call(Val(true), thunk.adjoint, AdjointThunk, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) @@ -5398,11 +5398,23 @@ end if !RawCall if rettype <: Active - @assert length(argtypes) + is_adjoint + needs_tape == length(argexprs) + if length(argtypes) + is_adjoint + needs_tape != length(argexprs) + return quote + throw(MethodError($CC($fptr), $args)) + end + end elseif rettype <: Const - @assert length(argtypes) + needs_tape == length(argexprs) + if length(argtypes) + needs_tape != length(argexprs) + return quote + throw(MethodError($CC($fptr), $args)) + end + end else - @assert length(argtypes) + needs_tape == length(argexprs) + if length(argtypes) + needs_tape != length(argexprs) + return quote + throw(MethodError($CC($fptr), $args)) + end + end end end diff --git a/test/runtests.jl b/test/runtests.jl index c512786a98..80d8a797ed 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -947,6 +947,16 @@ end @test res.y == nothing end +@testset "Methoe errors" begin + fwd = Enzyme.autodiff_thunk(Forward, Const{typeof(sum)}, Duplicated, Duplicated{Vector{Float64}}) + @test_throws MethodError fwd(ones(10)) + @test_throws MethodError fwd(Duplicated(ones(10), ones(10))) + @test_throws MethodError fwd(Const(first), Duplicated(ones(10), ones(10))) + # TODO + # @test_throws MethodError fwd(Const(sum), Const(ones(10))) + fwd(Const(sum), Duplicated(ones(10), ones(10))) +end + @testset "Generic Active Union Return" begin function generic_union_ret(A) From c6fb9368666d5d922bdf8c9d73e07e410f54d00b Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 13 May 2024 17:56:37 -0400 Subject: [PATCH 053/495] Mark newarray as noalias --- src/compiler.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2f17e8a4fb..67d64783f0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2815,7 +2815,8 @@ function annotate!(mod, mode) "jl_alloc_array_1d", "jl_alloc_array_2d", "jl_alloc_array_3d", "ijl_alloc_array_1d", "ijl_alloc_array_2d", "ijl_alloc_array_3d", "jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash", - "jl_f_tuple", "ijl_f_tuple", "jl_new_structv", "ijl_new_structv") + "jl_f_tuple", "ijl_f_tuple", "jl_new_structv", "ijl_new_structv", + "ijl_new_array", "jl_new_array") if haskey(fns, boxfn) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) From 27859773e566377542c11c445aa936ed69dbe14d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 13 May 2024 21:46:32 -0700 Subject: [PATCH 054/495] Fix eqtableget bug --- src/rules/llvmrules.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 1ce53d09f2..fa2efceed8 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -576,6 +576,15 @@ function eqtableget_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) return true end + + mode = get_mode(gutils) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + if needsShadowP[] == 0 + return false + end width = get_width(gutils) From 67ca3a1550556a5fb42afd6e3299295e672d5b1e Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Tue, 14 May 2024 18:51:01 -0400 Subject: [PATCH 055/495] More informative nullptr error --- src/absint.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 6216c1a769..36c1689832 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -114,7 +114,11 @@ function absint(arg::LLVM.Value, partial::Bool=false) ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) if ptr == C_NULL # XXX: Is this correct? - @error "Found null pointer" arg + bt = GPUCompiler.backtrace(arg) + btstr = sprint() do io + Base.show_backtrace(io, bt) + end + @error "Found null pointer at\n $btstr" arg return (false, nothing) end typ = Base.unsafe_pointer_to_objref(ptr) From cc8ceb6e9a5be2332dfa9688480c7e8266b95a0f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 14 May 2024 23:50:18 -0400 Subject: [PATCH 056/495] Improve calling conv error prints --- src/compiler/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index fc884f4561..4b38256e61 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -285,8 +285,8 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: end println(io, "ctype = ", ctype) println(io, "tape = ", tape) - println(io, "val = ", val) - println(io, "prev = ", prev) + println(io, "val = ", string(val)) + println(io, "prev = ", string(prev)) println(io, "lidxs = ", lidxs) println(io, "ridxs = ", ridxs) println(io, "tape_type(tape) = ", tape_type(tape)) From 446f04ed968e94e1fa8d3cf9ee22dcd4b47f460e Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 15 May 2024 09:38:57 -0700 Subject: [PATCH 057/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d85a82e3d3..c05a708ba2 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" 
ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.109" +Enzyme_jll = "0.0.110" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From 94b4ea43707ea14f4d1b9a199df287a409e308c8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 15 May 2024 09:48:18 -0700 Subject: [PATCH 058/495] recursion perf v3 (#1440) * recursion perf v3 * add more yelling --- src/compiler.jl | 1 + src/compiler/optimize.jl | 8 ++++++++ test/runtests.jl | 41 ++++++++++++++++++++++++++++++++++------ 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 67d64783f0..cedd6a6ab0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1348,6 +1348,7 @@ function emit_error(B::LLVM.IRBuilder, orig, string) # 2. Call error function and insert unreachable ct = call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("noreturn")) + LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_error")) return ct # FIXME(@wsmoses): Allow for emission of new BB in this code path # unreachable!(B) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 7583d957c0..0c26d3fb46 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -833,6 +833,10 @@ function mayWriteToMemory(inst::LLVM.Instruction)::Bool if kind(attr) == kind(EnumAttribute("readonly")) return false end + # Note out of spec, and only legal in context of removing unused calls + if kind(attr) == kind(StringAttribute("enzyme_error")) + return false + end end Libc.free(Attrs) return true @@ -1509,6 +1513,8 @@ function checkNoAssumeFalse(mod, shouldshow=false) end end +cse!(pm) = LLVM.API.LLVMAddEarlyCSEPass(pm) + function removeDeadArgs!(mod::LLVM.Module) # We need to run globalopt first. 
This is because remove dead args will otherwise # take internal functions and replace their args with undef. Then on LLVM up to @@ -1593,6 +1599,7 @@ function removeDeadArgs!(mod::LLVM.Module) instruction_combining!(pm) alloc_opt!(pm) scalar_repl_aggregates_ssa!(pm) # SSA variant? + cse!(pm) run!(pm, mod) end propagate_returned!(mod) @@ -1615,6 +1622,7 @@ function removeDeadArgs!(mod::LLVM.Module) API.EnzymeAddAttributorLegacyPass(pm) end end + cse!(pm) run!(pm, mod) end post_attr!(mod) diff --git a/test/runtests.jl b/test/runtests.jl index 80d8a797ed..b0107a786a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -198,25 +198,54 @@ end sumsq2(x) = sum(abs2, x) sumsin(x) = sum(sin, x) +sqrtsumsq2(x) = (sum(abs2, x)*sum(abs2,x)) @testset "Recursion optimization" begin # Test that we can successfully optimize out the augmented primal from the recursive divide and conquer fn = sprint() do io - Enzyme.Compiler.enzyme_code_llvm(io, sum, Active, Tuple{Duplicated{Vector{Float64}}}) + Enzyme.Compiler.enzyme_code_llvm(io, sum, Active, Tuple{Duplicated{Vector{Float64}}}; dump_module=true) end @test occursin("diffe",fn) - @test !occursin("aug",fn) + # TODO we need to fix julia to remove unused bounds checks + # @test !occursin("aug",fn) fn = sprint() do io - Enzyme.Compiler.enzyme_code_llvm(io, sumsq2, Active, Tuple{Duplicated{Vector{Float64}}}) + Enzyme.Compiler.enzyme_code_llvm(io, sumsq2, Active, Tuple{Duplicated{Vector{Float64}}}; dump_module=true) end @test occursin("diffe",fn) - @test !occursin("aug",fn) + # TODO we need to fix julia to remove unused bounds checks + # @test !occursin("aug",fn) fn = sprint() do io - Enzyme.Compiler.enzyme_code_llvm(io, sumsin, Active, Tuple{Duplicated{Vector{Float64}}}) + Enzyme.Compiler.enzyme_code_llvm(io, sumsin, Active, Tuple{Duplicated{Vector{Float64}}}; dump_module=true) end @test occursin("diffe",fn) - @test !occursin("aug",fn) + # TODO we need to fix julia to remove unused bounds checks + # @test !occursin("aug",fn) + + 
Enzyme.API.printall!(true) + fn = sprint() do io + Enzyme.Compiler.enzyme_code_llvm(io, sqrtsumsq2, Active, Tuple{Duplicated{Vector{Float64}}}; dump_module=true) + end + Enzyme.API.printall!(false) + @test occursin("diffe",fn) + if count("call fastcc void @diffejulia__mapreduce", fn) != 1 + println(sprint() do io + Enzyme.Compiler.enzyme_code_llvm(io, sqrtsumsq2, Active, Tuple{Duplicated{Vector{Float64}}}; dump_module=true, run_enzyme=false, optimize=false) + end) + println(sprint() do io + Enzyme.Compiler.enzyme_code_llvm(io, sqrtsumsq2, Active, Tuple{Duplicated{Vector{Float64}}}; dump_module=true, run_enzyme=false) + end) + println(fn) + end + # TODO per system being run on the indexing in the mapreduce is broken + # @test count("call fastcc void @diffejulia__mapreduce", fn) == 1 + # TODO we need to have enzyme circumvent the double pointer issue by also considering a broader + # no memory overwritten state [in addition to the arg-based variant] + @test_broken !occursin("aug",fn) + + x = ones(100) + dx = zeros(100) + Enzyme.autodiff(Reverse, sqrtsumsq2, Duplicated(x,dx)) end # @testset "Split Tape" begin From 9a63880dc91223304316c820f3f792a381cfcaf9 Mon Sep 17 00:00:00 2001 From: "Lance (Weiqing) Xu" <47257262+lanceXwq@users.noreply.github.com> Date: Thu, 16 May 2024 10:13:44 -0700 Subject: [PATCH 059/495] Fix a broken link and some typos in dev_docs.md (#1450) --- docs/src/dev_docs.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/dev_docs.md b/docs/src/dev_docs.md index 85b0ee6150..9467094e2a 100644 --- a/docs/src/dev_docs.md +++ b/docs/src/dev_docs.md @@ -2,7 +2,7 @@ ## Development of Enzyme and Enzyme.jl together (recommended) -Normally Enzyme.jl downloads and install [Enzyme](github.com/EnzymeAD/Enzyme) for the user automatically since Enzyme needs to be built against +Normally Enzyme.jl downloads and installs [Enzyme](https://github.com/EnzymeAD/enzyme) for the user automatically since Enzyme needs to be built against 
Julia bundeled LLVM. In case that you are making updates to Enzyme and want to test them against Enzyme.jl the instructions below should help you get started. @@ -45,7 +45,7 @@ julia> Enzyme_jll.libEnzyme_path This should correspond to the path in the LocalPreferences.toml you just generated. Note that your system can have only one custom built Enzyme\_jll at a time. If you build one version for one version of Enzyme or Julia -and later build a new verison of Enzyme, it remove the old build. +and later build a new version of Enzyme, it removes the old build. Note that Julia versions are tightly coupled and you cannot use an Enzyme\_jll built for one version of Julia for another version of Julia. @@ -55,7 +55,7 @@ The same script can also be used to build Enzyme\_jll for a branch other than ma ~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl --branch mybranch ``` -It can also be used to build Enzyme\_jll from a local copy of Enzyme on your machine, which do not need to be committed to git. +It can also be used to build Enzyme\_jll from a local copy of Enzyme on your machine, which does not need to be committed to git. ```bash ~/s/Enzyme.jl (master)> julia --project=deps deps/build_local.jl ../path/to/Enzyme @@ -105,7 +105,7 @@ julia> Base.libllvm_version_string "9.0.1jl" ``` -If the LLVM version ends in a `jl` you a likely using the private LLVM. +If the LLVM version ends in a `jl` you are likely using the private LLVM. 
In your source checkout of Enzyme: From d5823087c26b2a276946c1327acd701ac58a04d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Thu, 16 May 2024 19:14:25 +0200 Subject: [PATCH 060/495] Fix type escaping in `@import_frule`, `@import_rrule` (#1446) * Fix type escaping in `@import_frule`, `@import_rrule` * Test fix * Implement `fdiff`,`rdiff` methods for `MockType` tests * Comment PR in tests * Some fixes * Fix return type --- ext/EnzymeChainRulesCoreExt.jl | 4 ++-- test/ext/chainrulescore.jl | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl index 4549aa84b8..9da9eb97fd 100644 --- a/ext/EnzymeChainRulesCoreExt.jl +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -45,7 +45,7 @@ function Enzyme._import_frule(fn, tys...) val = Symbol("arg_$i") TA = Symbol("AN_$i") e = :($val::$TA) - push!(anns, :($TA <: Annotation{<:$ty})) + push!(anns, :($TA <: Annotation{<:$(esc(ty))})) push!(vals, val) push!(exprs, e) push!(primals, :($val.val)) @@ -156,7 +156,7 @@ function Enzyme._import_rrule(fn, tys...) 
val = Symbol("arg_$i") TA = Symbol("AN_$i") e = :($val::$TA) - push!(anns, :($TA <: Annotation{<:$ty})) + push!(anns, :($TA <: Annotation{<:$(esc(ty))})) push!(vals, val) push!(exprs, e) primal = Symbol("primcopy_$i") diff --git a/test/ext/chainrulescore.jl b/test/ext/chainrulescore.jl index 38fe07e87f..b73117faf2 100644 --- a/test/ext/chainrulescore.jl +++ b/test/ext/chainrulescore.jl @@ -5,7 +5,27 @@ using ChainRulesCore using LinearAlgebra using EnzymeTestUtils +module MockModule + struct MockType + x::Float32 + end + + mock_function(x::MockType) = 2 * x.x +end + +function ChainRulesCore.frule((_, ẋ), ::typeof(MockModule.mock_function), x) + y = MockModule.mock_function(x) + ẏ = 3 * ẋ.x + return y, ẏ +end + +function ChainRulesCore.rrule(::typeof(MockModule.mock_function), x) + y = MockModule.mock_function(x) + return y, ȳ -> 2 * ȳ +end + fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] +fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x, MockModule.MockType(one(x.x))))[2] @testset "import_frule" begin f1(x) = 2*x @@ -28,6 +48,10 @@ fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] @test fdiff(x -> f3(x, 1.0), 2.) === 5.0 @test fdiff(y -> f3(1.0, y), 2.) 
=== 2.0 + # external module (checks correct type escaping, PR #1446) + Enzyme.@import_frule typeof(MockModule.mock_function) MockModule.MockType + @test fdiff(MockModule.mock_function, MockModule.MockType(1f0)) === 3f0 + @testset "batch duplicated" begin x = [1.0, 2.0, 0.0] Enzyme.@import_frule typeof(Base.sort) Any @@ -65,6 +89,7 @@ fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] end rdiff(f, x::Number) = autodiff(Reverse, f, Active, Active(x))[1][1] +rdiff(f, x::MockModule.MockType) = autodiff(Reverse, f, Active, Active(x))[1][1] @testset "import_rrule" begin f1(x) = 2*x @@ -87,6 +112,10 @@ rdiff(f, x::Number) = autodiff(Reverse, f, Active, Active(x))[1][1] @test rdiff(x -> f3(x, 1.0), 2.) === 5.0 @test rdiff(y -> f3(1.0, y), 2.) === 2.0 + # external module (checks correct type escaping, PR #1446) + Enzyme.@import_rrule typeof(MockModule.mock_function) MockModule.MockType + @test rdiff(MockModule.mock_function, MockModule.MockType(1f0)) === MockModule.MockType(2f0) + @testset "batch duplicated" begin x = [1.0, 2.0, 0.0] Enzyme.@import_rrule typeof(Base.sort) Any From d84ac4866512ed50dd3cef54c73efcb12018680f Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 16 May 2024 13:58:03 -0700 Subject: [PATCH 061/495] Fix newstruct --- src/rules/typeunstablerules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 91b1ee837e..546bf67d7c 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -103,7 +103,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) ty = new_from_original(gutils, origops[offset]) for v in origops[offset+1:end-1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, new_from_original(gutils, v)), ty]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, lookup_value(gutils, new_from_original(gutils, v), 
B)), ty]) end return nothing From 986fc7125314d702c9c630d0568017eaece5023a Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 16 May 2024 14:09:35 -0700 Subject: [PATCH 062/495] Fix newstruct pt2 --- src/rules/typeunstablerules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 546bf67d7c..6b0c3e5d35 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -101,7 +101,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] - ty = new_from_original(gutils, origops[offset]) + ty = lookup_value(gutils, new_from_original(gutils, origops[offset]), B) for v in origops[offset+1:end-1] emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, lookup_value(gutils, new_from_original(gutils, v), B)), ty]) end From a99a3e0db3cb09c52fcc2dec2c8af35f45dd30e1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 17 May 2024 00:29:37 -0700 Subject: [PATCH 063/495] Handle mixed activity of literal 0 constant (#1449) --- src/compiler.jl | 138 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 103 insertions(+), 35 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index cedd6a6ab0..8340faba94 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1755,92 +1755,160 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err elseif errtype == API.ET_MixedActivityError data2 = LLVM.Value(data2) badval = nothing + gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) # Ignore mismatched activity if phi/store of ghost - todo = LLVM.Value[data2] - seen = Set{LLVM.Value}() + seen = Dict{LLVM.Value, LLVM.Value}() illegal = false - while length(todo) != 0 - cur = pop!(todo) - if cur in seen - continue - end - push!(seen, cur) - if isa(cur, LLVM.PHIInst) - for v in LLVM.incoming(cur) - push!(todo, cur) - end - continue 
+ created = LLVM.Instruction[] + function make_replacement(cur::LLVM.Value, prevbb)::LLVM.Value + ncur = new_from_original(gutils, cur) + if cur in keys(seen) + return seen[cur] end - + legal, TT = abs_typeof(cur, true) if legal world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) if guaranteed_const_nongen(TT, world) - continue + return ncur end + legal2, obj = absint(cur) + + if legal2 && active_reg_inner(TT, (), world) == ActiveState && isa(cur, LLVM.ConstantExpr) + res = emit_allocobj!(prevbb, Base.RefValue{TT}) + push!(created, res) + return res + end + badval = if legal2 string(obj)*" of type"*" "*string(TT) else "Unknown object of type"*" "*string(TT) end illegal = true - break + return ncur end + if isa(cur, LLVM.PointerNull) - continue + return ncur end if isa(cur, LLVM.UndefValue) - continue + return ncur end @static if LLVM.version() >= v"12" if isa(cur, LLVM.PoisonValue) - continue + return ncur end end if isa(cur, LLVM.ConstantAggregateZero) - continue + return ncur end if isa(cur, LLVM.ConstantAggregate) - continue + return ncur end if isa(cur, LLVM.ConstantDataSequential) + cvals = LLVM.Value[] + changed = false for v in collect(cur) - push!(todo, v) + tmp = make_replacement(v, prevbb) + if illegal + return cur + end + if v != tmp + changed = true + end + push!(todo, tmp) end - continue + + cur2 = if changed + illegal = true + # TODO replace with correct insertions/splats + ncur + else + ncur + end + return cur2 end if isa(cur, LLVM.ConstantInt) if width(value_type(cur)) <= 8 - continue + return ncur end # if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe # for activity state to mix if isa(val, LLVM.StoreInst) operands(val)[1] == cur && !isa(value_type(operands(val)[1]), LLVM.PointerType) - continue + return ncur end end + + if isa(cur, LLVM.PHIInst) + B = IRBuilder() + position!(B, ncur) + phi2 = phi!(prevbb, value_type(cur), "tempphi"*LLVM.name(cur)) + seen[cur] = phi2 + changed = false + recsize = 
length(created)+1 + for (v, bb) in LLVM.incoming(cur) + B2 = IRBuilder() + position!(B2, last(instructions(bb))) + tmp = make_replacement(v, B2) + if illegal + changed = true + break + end + if tmp != v && v != cur + changed = true + break + end + push!(LLVM.incoming(phi2), (tmp, bb)) + end + if !changed || illegal + LLVM.API.LLVMInstructionEraseFromParent(phi2) + seen[cur] = ncur + plen = length(created) + for i in recsize:plen + u = created[i] + replace_uses!(u, LLVM.UndefValue(value_type(u))) + end + for i in recsize:plen + u = created[i] + LLVM.API.LLVMInstructionEraseFromParent(u) + end + for i in recsize:plen + pop!(created) + end + return ncur + end + push!(created, phi2) + return phi2 + end + illegal = true - break + return ncur end - if !illegal - return C_NULL + newb = new_from_original(gutils, val) + while isa(newb, LLVM.PHIInst) + newb = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(newb)) end + b = IRBuilder(B) + replacement = make_replacement(data2, b) + if !illegal + return replacement.ref + end + for u in created + replace_uses!(u, LLVM.UndefValue(value_type(u))) + end + for u in created + LLVM.API.LLVMInstructionEraseFromParent(u) + end if LLVM.API.LLVMIsAReturnInst(val) != C_NULL mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false) if mi !== nothing && isghostty(rt) return C_NULL end end - - gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) - newb = new_from_original(gutils, val) - while isa(newb, LLVM.PHIInst) - newb = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(newb)) - end - b = IRBuilder(B) msg2 = sprint() do io print(io, msg) println(io) From 90370474604a21f84498faa56a20d03754b824c2 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 17 May 2024 04:20:47 -0400 Subject: [PATCH 064/495] Embarassing bugfix for mixedactivity --- src/compiler.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8340faba94..91e119bdb0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1760,6 +1760,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err seen = Dict{LLVM.Value, LLVM.Value}() illegal = false created = LLVM.Instruction[] + world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) function make_replacement(cur::LLVM.Value, prevbb)::LLVM.Value ncur = new_from_original(gutils, cur) if cur in keys(seen) @@ -1768,7 +1769,6 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err legal, TT = abs_typeof(cur, true) if legal - world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) if guaranteed_const_nongen(TT, world) return ncur end @@ -1842,9 +1842,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end if isa(cur, LLVM.PHIInst) - B = IRBuilder() - position!(B, ncur) - phi2 = phi!(prevbb, value_type(cur), "tempphi"*LLVM.name(cur)) + Bphi = IRBuilder() + position!(Bphi, ncur) + phi2 = phi!(Bphi, value_type(cur), "tempphi"*LLVM.name(cur)) seen[cur] = phi2 changed = false recsize = length(created)+1 @@ -1856,11 +1856,10 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err changed = true break end - if tmp != v && v != cur + if tmp != new_from_original(gutils, v) && v != cur changed = true - break end - push!(LLVM.incoming(phi2), (tmp, bb)) + push!(LLVM.incoming(phi2), (tmp, new_from_original(gutils, bb))) end if !changed || illegal LLVM.API.LLVMInstructionEraseFromParent(phi2) From 14409ce4c054d883ddbbf03e783d4057988a25ae Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 17 May 2024 01:41:03 -0700 Subject: [PATCH 065/495] Update Enzyme.jl to make import chainrules docs appear, 
hopefully --- src/Enzyme.jl | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 9087055022..c75508cd77 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1227,12 +1227,63 @@ end function _import_frule end # defined in EnzymeChainRulesCoreExt extension +""" + import_frule(::fn, tys...) + +Automatically import a `ChainRulesCore.frule`` as a custom forward mode `EnzymeRule`. When called in batch mode, this +will end up calling the primal multiple times, which may result in incorrect behavior if the function mutates, +and slow code, always. Importing the rule from `ChainRules` is also likely to be slower than writing your own rule, +and may also be slower than not having a rule at all. + +Use with caution. + +```julia +Enzyme.@import_frule(typeof(Base.sort), Any); + +x=[1.0, 2.0, 0.0]; dx=[0.1, 0.2, 0.3]; ddx = [0.01, 0.02, 0.03]; + +Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,ddx))) +Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,ddx))) +Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,))) +Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,))) + +# output + +(var"1" = [0.0, 1.0, 2.0], var"2" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02])) +(var"1" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]),) +(var"1" = [0.3, 0.1, 0.2],) +(var"1" = [0.0, 1.0, 2.0], var"2" = [0.3, 0.1, 0.2]) + +``` +""" macro import_frule(args...) return _import_frule(args...) end function _import_rrule end # defined in EnzymeChainRulesCoreExt extension +""" + import_rrule(::fn, tys...) + +Automatically import a ChainRules.rrule as a custom reverse mode EnzymeRule. When called in batch mode, this +will end up calling the primal multiple times which results in slower code. This macro assumes that the underlying +function to be imported is read-only, and returns a Duplicated or Const object. 
This macro also assumes that the +inputs permit a .+= operation and that the output has a valid Enzyme.make_zero function defined. It also assumes +that overwritten(x) accurately describes if there is any non-preserved data from forward to reverse, not just +the outermost data structure being overwritten as provided by the specification. + +Finally, this macro falls back to almost always caching all of the inputs, even if it may not be needed for the +derivative computation. + +As a result, this auto importer is also likely to be slower than writing your own rule, and may also be slower +than not having a rule at all. + +Use with caution. + +```julia +Enzyme.@import_rrule(typeof(Base.sort), Any); +``` +""" macro import_rrule(args...) return _import_rrule(args...) end From 1bdd127edd548e9307124f76f0f63423fcfb2563 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 17 May 2024 10:47:20 -0700 Subject: [PATCH 066/495] Add jl_simpliy pass (#1445) * Add jl_simpliy pass * Update Project.toml --- Project.toml | 2 +- src/compiler/optimize.jl | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c05a708ba2..1e2d587658 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.110" +Enzyme_jll = "0.0.111" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 0c26d3fb46..842ebe0fe3 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -735,6 +735,10 @@ function pre_attr!(mod::LLVM.Module) return nothing end +function jl_inst_simplify!(PM) + ccall((:LLVMAddJLInstSimplifyPass, API.libEnzyme), Cvoid, (LLVM.API.LLVMPassManagerRef,), PM) +end + function post_attr!(mod::LLVM.Module) end @@ -1597,6 +1601,7 @@ function removeDeadArgs!(mod::LLVM.Module) 
propagate_returned!(mod) ModulePassManager() do pm instruction_combining!(pm) + jl_inst_simplify!(pm) alloc_opt!(pm) scalar_repl_aggregates_ssa!(pm) # SSA variant? cse!(pm) @@ -1615,6 +1620,7 @@ function removeDeadArgs!(mod::LLVM.Module) propagate_returned!(mod) ModulePassManager() do pm instruction_combining!(pm) + jl_inst_simplify!(pm) alloc_opt!(pm) scalar_repl_aggregates_ssa!(pm) # SSA variant? if RunAttributor[] @@ -1676,12 +1682,15 @@ end LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Extra gvn!(pm) # Extra instruction_combining!(pm) + jl_inst_simplify!(pm) cfgsimplification!(pm) scalar_repl_aggregates_ssa!(pm) # SSA variant? instruction_combining!(pm) + jl_inst_simplify!(pm) jump_threading!(pm) correlated_value_propagation!(pm) instruction_combining!(pm) + jl_inst_simplify!(pm) reassociate!(pm) early_cse!(pm) alloc_opt!(pm) @@ -1695,6 +1704,7 @@ end loop_unswitch!(pm) end instruction_combining!(pm) + jl_inst_simplify!(pm) ind_var_simplify!(pm) loop_deletion!(pm) loop_unroll!(pm) @@ -1705,9 +1715,11 @@ end # This InstCombine needs to be after GVN # Otherwise it will generate load chains in GPU code... 
instruction_combining!(pm) + jl_inst_simplify!(pm) mem_cpy_opt!(pm) sccp!(pm) instruction_combining!(pm) + jl_inst_simplify!(pm) jump_threading!(pm) dead_store_elimination!(pm) alloc_opt!(pm) @@ -1722,6 +1734,7 @@ end aggressive_dce!(pm) instruction_combining!(pm) + jl_inst_simplify!(pm) # Loop Vectorize -- not for Enzyme # InstCombine @@ -1732,6 +1745,7 @@ end # FIXME: Currently crashes printing cfgsimplification!(pm) instruction_combining!(pm) # Extra for Enzyme + jl_inst_simplify!(pm) LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Exxtra gvn!(pm) # Exxtra run!(pm, mod) @@ -1773,9 +1787,11 @@ function addOptimizationPasses!(pm) # consider AggressiveInstCombinePass at optlevel > 2 instruction_combining!(pm) + jl_inst_simplify!(pm) cfgsimplification!(pm) scalar_repl_aggregates!(pm) instruction_combining!(pm) # TODO: createInstSimplifyLegacy + jl_inst_simplify!(pm) jump_threading!(pm) correlated_value_propagation!(pm) @@ -1796,6 +1812,7 @@ function addOptimizationPasses!(pm) julia_licm!(pm) # Subsequent passes not stripping metadata from terminator instruction_combining!(pm) # TODO: createInstSimplifyLegacy + jl_inst_simplify!(pm) ind_var_simplify!(pm) loop_deletion!(pm) loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll @@ -1806,6 +1823,7 @@ function addOptimizationPasses!(pm) # over the structure of an aggregate) scalar_repl_aggregates!(pm) instruction_combining!(pm) # TODO: createInstSimplifyLegacy + jl_inst_simplify!(pm) gvn!(pm) mem_cpy_opt!(pm) @@ -1816,6 +1834,7 @@ function addOptimizationPasses!(pm) # This needs to be InstCombine instead of InstSimplify to allow # loops over Union-typed arrays to vectorize. 
instruction_combining!(pm) + jl_inst_simplify!(pm) jump_threading!(pm) dead_store_elimination!(pm) @@ -1829,6 +1848,7 @@ function addOptimizationPasses!(pm) cfgsimplification!(pm) loop_deletion!(pm) instruction_combining!(pm) + jl_inst_simplify!(pm) loop_vectorize!(pm) # TODO: createLoopLoadEliminationPass cfgsimplification!(pm) @@ -1873,6 +1893,7 @@ function addJuliaLegalizationPasses!(pm, lower_intrinsics=true) dce!(pm) lower_ptls!(pm, #=dump_native=# false) instruction_combining!(pm) + jl_inst_simplify!(pm) # Clean up write barrier and ptls lowering cfgsimplification!(pm) else From a68bf8361cf4290baf3e8b40caed02f8151cf95b Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 17 May 2024 18:02:49 -0500 Subject: [PATCH 067/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1e2d587658..938be7413c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.6" +version = "0.12.7" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 28f855b8b6c2fee6c08019b8f073845cb6bbe449 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 20 May 2024 10:03:43 -0500 Subject: [PATCH 068/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 938be7413c..4d8dd580b0 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.111" +Enzyme_jll = "0.0.112" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From d5b9d50c734e497c5cc76934ab9a1bd075ef56bd Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 20 May 2024 17:13:42 -0500 Subject: [PATCH 069/495] Type unstable setfield support (#1455) * Type unstable setfield 
support * with test fix * nulls are for squares * fix --- src/compiler.jl | 7 +++ src/rules/typeunstablerules.jl | 79 +++++++++++++++++++++++++++++++++- test/runtests.jl | 34 +++++++++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 91e119bdb0..fc3103aa08 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1194,6 +1194,13 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) end end + +@inline function EnzymeCore.make_zero(x::FT)::FT where {FT <: AbstractFloat} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT <: AbstractFloat} + return Base.zero(x) +end @inline function EnzymeCore.make_zero(x::Array{FT, N})::Array{FT, N} where {FT <: AbstractFloat, N} return Base.zero(x) end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 6b0c3e5d35..f70e15c82c 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -714,17 +714,92 @@ function common_setfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end + +function rt_jl_setfield_aug(dptr::T, idx, ::Val{isconst}, val, dval) where {T, isconst} + RT = Core.Typeof(val) + if active_reg(RT) + setfield!(dptr, idx, make_zero(val)) + else + setfield!(dptr, idx, isconst ? val : dval) + end +end + +function rt_jl_setfield_rev(dptr::T, idx, ::Val{isconst}, val, dval) where {T, isconst} + RT = Core.Typeof(val) + if active_reg(RT) && !isconst + dval[] += getfield(dptr, idx) + setfield!(dptr, idx, make_zero(val)) + end +end + function common_setfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_f_setfield") + normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end + + origops = collect(operands(orig))[offset:end] + if !is_constant_value(gutils, origops[2]) + width = get_width(gutils) + + shadowstruct = invert_pointer(gutils, origops[2], B) + + shadowval = if !is_constant_value(gutils, origops[2]) + invert_pointer(gutils, origops[4], B) + else + nothing + end + + for idx in 1:width + vals = LLVM.Value[ + (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx-1), + new_from_original(gutils, origops[3]), + unsafe_to_llvm(Val(is_constant_value(gutils, origops[4]))), + new_from_original(gutils, origops[4]), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(nothing) : ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), + ] + + pushfirst!(vals, unsafe_to_llvm(rt_jl_setfield_aug)) + + cal = emit_apply_generic!(B, vals) + + debug_from_orig!(gutils, cal, orig) + end + end + return false end function common_setfield_rev(offset, B, orig, gutils, tape) - emit_error(B, orig, "Enzyme: unhandled reverse for jl_f_setfield") + origops = collect(operands(orig))[offset:end] + if !is_constant_value(gutils, origops[2]) + width = get_width(gutils) + + shadowstruct = invert_pointer(gutils, origops[2], B) + + shadowval = if !is_constant_value(gutils, origops[2]) + invert_pointer(gutils, origops[4], B) + else + nothing + end + + for idx in 1:width + vals = LLVM.Value[ + lookup_value(gutils, (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx-1), B), + lookup_value(gutils, new_from_original(gutils, origops[3]), B), + unsafe_to_llvm(Val(is_constant_value(gutils, origops[4]))), + lookup_value(gutils, new_from_original(gutils, origops[4]), B), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(nothing) : lookup_value(gutils, ((width == 1) ? 
shadowval : extract_value!(B, shadowval, idx-1)), B), + ] + + pushfirst!(vals, unsafe_to_llvm(rt_jl_setfield_rev)) + + cal = emit_apply_generic!(B, vals) + + debug_from_orig!(gutils, cal, orig) + end + end return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index b0107a786a..47d9debbff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2175,6 +2175,40 @@ end @test mt2.y ≈ 6.0 @test dmt2.x ≈ 1.2 @test dmt2.y ≈ 2.4 + + function sf_for2(v, fld, fld2, x) + setfield!(v, fld, 0.0) + for i in 1:100 + setfield!(v, fld2, getfield(v, fld)::Float64 + x * i) + end + return getfield(v, fld)::Float64 + end + + mt2 = MyType2(0.0, 0.0) + dmt2 = MyType2(0.0, 0.0) + + adres = Enzyme.autodiff(Reverse, sf_for2, Duplicated(mt2, dmt2), Const(:x), Const(:x), Active(3.1)) + @test adres[1][4] ≈ 5050.0 + + mutable struct MyType3 + x::Base.RefValue{Float64} + y::Base.RefValue{Float64} + end + + function sf_for3(v, fld, fld2, x) + setfield!(v, fld, Ref(0.0)) + for i in 1:100 + setfield!(v, fld2, Base.Ref((getfield(v, fld)::Base.RefValue{Float64})[] + x * i)) + end + return (getfield(v, fld)::Base.RefValue{Float64})[] + end + + mt3 = MyType3(Ref(0.0), Ref(0.0)) + dmt3 = MyType3(Ref(0.0), Ref(0.0)) + + adres = Enzyme.autodiff(Reverse, sf_for3, Duplicated(mt3, dmt3), Const(:x), Const(:x), Active(3.1)) + @test adres[1][4] ≈ 5050.0 + end From c4725f1fcac1bf7fe097f280e707b625d88d8660 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 21 May 2024 09:00:16 -0500 Subject: [PATCH 070/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4d8dd580b0..ed7c47b719 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.112" +Enzyme_jll = "0.0.113" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From 
a654fd5b627f37f6660a4a50ab4e73c9d375ae29 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 21 May 2024 09:23:24 -0500 Subject: [PATCH 071/495] Batched mixed activity rewrite (#1457) * Batched mixed activity rewrite * fix --- src/api.jl | 2 + src/compiler.jl | 116 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 96 insertions(+), 22 deletions(-) diff --git a/src/api.jl b/src/api.jl index 84080d4b49..c025238abc 100644 --- a/src/api.jl +++ b/src/api.jl @@ -230,6 +230,8 @@ const CustomReversePass = Ptr{Cvoid} EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle) EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle) +EnzymeInsertValue(B::LLVM.IRBuilder, v::LLVM.Value, v2::LLVM.Value, insts::Vector{Cuint}, name="") = LLVM.Value(ccall((:EnzymeInsertValue, libEnzyme), LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVMValueRef, LLVMValueRef, Ptr{Cuint}, Int64, Cstring), B, v, v2, insts, length(insts), name)) + EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall((:EnzymeSetCalledFunction, libEnzyme), Cvoid, (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), ci, fn, toremove, length(toremove)) EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall((:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), LLVMValueRef, (LLVMValueRef,UInt8,Ptr{Int64}, Int64), fn, keepret, args, length(args)) EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T) diff --git a/src/compiler.jl b/src/compiler.jl index fc3103aa08..769451f29c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1768,6 +1768,24 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err illegal = false created = LLVM.Instruction[] 
world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) + width = get_width(gutils) + function make_batched(cur, B) + if width == 1 + return cur + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur)))) + for idx in 1:width + shadowres = insert_value!(B, shadowres, cur, idx-1) + if isa(shadowres, LLVM.Instruction) + push!(created, shadowres) + end + end + return shadowres + end + end + + illegalVal = nothing + function make_replacement(cur::LLVM.Value, prevbb)::LLVM.Value ncur = new_from_original(gutils, cur) if cur in keys(seen) @@ -1777,15 +1795,26 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err legal, TT = abs_typeof(cur, true) if legal if guaranteed_const_nongen(TT, world) - return ncur + return make_batched(ncur, prevbb) end legal2, obj = absint(cur) - if legal2 && active_reg_inner(TT, (), world) == ActiveState && isa(cur, LLVM.ConstantExpr) - res = emit_allocobj!(prevbb, Base.RefValue{TT}) - push!(created, res) - return res + # Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple + if legal2 && active_reg_inner(TT, (), world) == ActiveState && isa(cur, LLVM.ConstantExpr) && cur == data2 + if width == 1 + res = emit_allocobj!(prevbb, Base.RefValue{TT}) + push!(created, res) + return res + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur)))) + for idx in 1:width + res = emit_allocobj!(prevbb, Base.RefValue{TT}) + shadowres = insert_value!(prevbb, shadowres, res, idx-1) + push!(created, shadowres) + end + return shadowres + end end badval = if legal2 @@ -1793,26 +1822,32 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err else "Unknown object of type"*" "*string(TT) end + illegalVal = cur illegal = true - return ncur + return make_batched(ncur, prevbb) end if isa(cur, LLVM.PointerNull) - return ncur + return make_batched(ncur, prevbb) end if isa(cur, LLVM.UndefValue) - 
return ncur + return make_batched(ncur, prevbb) end @static if LLVM.version() >= v"12" if isa(cur, LLVM.PoisonValue) - return ncur + return make_batched(ncur, prevbb) end end if isa(cur, LLVM.ConstantAggregateZero) - return ncur + return make_batched(ncur, prevbb) end if isa(cur, LLVM.ConstantAggregate) - return ncur + return make_batched(ncur, prevbb) + end + if isa(cur, LLVM.ConstantInt) + if convert(UInt64, cur) == 0 + return make_batched(ncur, prevbb) + end end if isa(cur, LLVM.ConstantDataSequential) cvals = LLVM.Value[] @@ -1820,7 +1855,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err for v in collect(cur) tmp = make_replacement(v, prevbb) if illegal - return cur + return ncur end if v != tmp changed = true @@ -1829,40 +1864,77 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end cur2 = if changed + illegalVal = cur illegal = true # TODO replace with correct insertions/splats ncur else - ncur + make_batched(ncur, prevbb) end return cur2 end if isa(cur, LLVM.ConstantInt) - if width(value_type(cur)) <= 8 - return ncur + if LLVM.width(value_type(cur)) <= 8 + return make_batched(ncur, prevbb) end # if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe # for activity state to mix if isa(val, LLVM.StoreInst) operands(val)[1] == cur && !isa(value_type(operands(val)[1]), LLVM.PointerType) + return make_batched(ncur, prevbb) + end + end + + if isa(cur, LLVM.InsertValueInst) + lhs = make_replacement(operands(cur)[1], prevbb) + if illegal + return ncur + end + rhs = make_replacement(operands(cur)[2], prevbb) + if illegal return ncur end + if lhs == operands(cur)[1] && rhs == operands(cur)[2] + return make_batched(ncur, prevbb) + end + inds = LLVM.API.LLVMGetIndices(cur.ref) + ninds = LLVM.API.LLVMGetNumIndices(cur.ref) + jinds = Cuint[unsafe_load(inds, i) for i in 1:ninds] + if width == 1 + nv = API.EnzymeInsertValue(prevbb, lhs, rhs, jinds) + push!(created, nv) + 
seen[cur] = nv + return nv + else + shadowres = lhs + for idx in 1:width + jindsv = copy(jinds) + pushfirst!(jindsv, idx-1) + shadowres = API.EnzymeInsertValue(prevbb, shadowres, extract_value!(prevbb, rhs, idx-1), jindsv) + if isa(shadowres, LLVM.Instruction) + push!(created, shadowres) + end + end + return shadowres + end end if isa(cur, LLVM.PHIInst) Bphi = IRBuilder() position!(Bphi, ncur) - phi2 = phi!(Bphi, value_type(cur), "tempphi"*LLVM.name(cur)) + shadowty = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))) + phi2 = phi!(Bphi, shadowty, "tempphi"*LLVM.name(cur)) seen[cur] = phi2 changed = false recsize = length(created)+1 for (v, bb) in LLVM.incoming(cur) B2 = IRBuilder() - position!(B2, last(instructions(bb))) + position!(B2, new_from_original(gutils, last(instructions(bb)))) tmp = make_replacement(v, B2) if illegal changed = true break end + @assert value_type(tmp) == shadowty if tmp != new_from_original(gutils, v) && v != cur changed = true end @@ -1883,20 +1955,17 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err for i in recsize:plen pop!(created) end - return ncur + return illegal ? ncur : make_batched(ncur, prevbb) end push!(created, phi2) return phi2 end illegal = true + illegalVal = cur return ncur end - newb = new_from_original(gutils, val) - while isa(newb, LLVM.PHIInst) - newb = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(newb)) - end b = IRBuilder(B) replacement = make_replacement(data2, b) @@ -1931,6 +2000,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err println(io, Base.unsafe_string(st)) API.EnzymeStringFree(st) end + if illegalVal !== nothing + println(io, " llvalue="*string(illegalVal)) + end println(io, "You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). 
If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now") if bt !== nothing Base.show_backtrace(io, bt) From 62b4d5952f88b314650893ed73c026f6f677165d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 21 May 2024 09:25:49 -0500 Subject: [PATCH 072/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ed7c47b719..eb689480c5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.7" +version = "0.12.8" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 4402bb602fb648f9257f8509dcbca651e97b533a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 22 May 2024 12:56:24 -0400 Subject: [PATCH 073/495] Handle function type in to_tape --- src/compiler.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 769451f29c..bf8cd23114 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2107,7 +2107,12 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} return Any, true else e = LLVM.API.LLVMGetElementType(Type) - return Core.LLVMPtr{to_tape_type(e)[1], Int(addrspace)}, false + tkind2 = LLVM.API.LLVMGetTypeKind(e) + if tkind2 == LLVM.API.LLVMFunctionTypeKind + return Core.LLVMPtr{Cvoid, Int(addrspace)}, false + else + return Core.LLVMPtr{to_tape_type(e)[1], Int(addrspace)}, false + end end end if tkind == LLVM.API.LLVMArrayTypeKind @@ -2170,7 +2175,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} if tkind == LLVM.API.LLVMFP128TypeKind return Float128, false end - error("Can't construct tape type for $Type") + error("Can't construct tape type for $Type $(string(Type)) $tkind") end function tape_type(LLVMType::LLVM.LLVMType) From 
9aaf6d41b48f6b99c46909d076904f7f31740c00 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 22 May 2024 13:13:17 -0500 Subject: [PATCH 074/495] Add max type depth --- src/api.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/api.jl b/src/api.jl index c025238abc..e7d37f8cb5 100644 --- a/src/api.jl +++ b/src/api.jl @@ -467,6 +467,25 @@ function maxtypeoffset!(val) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, Int64), ptr, val) end +""" + maxtypedepth!(val::Integer) + +Enzyme runs a type analysis to deduce the corresponding types of all values being +differentiated. This is necessary to compute correct derivatives of various values. +To ensure this analysis terminates, it operates on a finite lattice of possible +states. This function sets the maximum depth into a type that Enzyme will consider. +A smaller value will cause type analysis to run faster, but may result in some +necessary types not being found and result in unknown type errors. A larger value +may result in unknown type errors being resolved by searching a larger space, but +may run longer. The default setting is 6. 
+""" +function maxtypedepth!(val) + ptr = cglobal((:EnzymeMaxTypeDepth, libEnzyme)) + ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, Int64), ptr, val) +end + + + """ looseTypeAnalysis!(val::Bool) From d230f661c5ea775da8e723591ee7e148d0e2d336 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 22 May 2024 19:30:03 -0500 Subject: [PATCH 075/495] Fix unused shadow return type (#1462) --- src/rules/jitrules.jl | 82 ++++++++++++++++++++++++++++++++++--------- test/runtests.jl | 73 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 16 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 2f818df4a0..b16fe9c481 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -455,7 +455,11 @@ function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -504,7 +508,11 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -552,12 +560,17 @@ function generic_augfwd(B, orig, gutils, normalR, shadowR, tapeR) end function common_generic_rev(offset, B, orig, gutils, tape)::Cvoid - 
if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - @assert tape !== C_NULL - width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset, B, true; tape) + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing end + + @assert tape !== C_NULL + width = get_width(gutils) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset, B, true; tape) return nothing end @@ -572,7 +585,11 @@ function generic_rev(B, orig, gutils, tape)::Cvoid end function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -613,7 +630,11 @@ function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) end function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -656,6 +677,13 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t end function common_apply_latest_rev(offset, B, orig, gutils, tape)::Cvoid + needsShadowP = Ref{UInt8}(0) 
+ needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing + end if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) width = get_width(gutils) generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) @@ -690,10 +718,14 @@ function apply_latest_rev(B, orig, gutils, tape) end function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end - + v, isiter = absint(operands(orig)[offset+1]) v2, istup = absint(operands(orig)[offset+2]) @@ -778,7 +810,11 @@ function error_if_active_iter(arg) end function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -859,7 +895,11 @@ function apply_iterate_rev(B, orig, gutils, tape) end function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, 
orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -899,7 +939,11 @@ function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) end function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing @@ -946,10 +990,16 @@ function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) end function common_invoke_rev(offset, B, orig, gutils, tape) - if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) - width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing end + + width = get_width(gutils) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 47d9debbff..f0f2aa9fa0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3051,6 +3051,79 @@ end end end +const CUmemoryPool2 = Ptr{Float64} + +struct CUmemPoolProps2 + reserved::NTuple{31,Char} +end + +mutable struct CuMemoryPool2 + handle::CUmemoryPool2 +end + +function ccall_macro_lower(func, rettype, types, args, nreq) + # instead of re-using ccall or Expr(:foreigncall) to perform argument conversion, + # we need to do so 
ourselves in order to insert a jl_gc_safe_enter|leave + # just around the inner ccall + + cconvert_exprs = [] + cconvert_args = [] + for (typ, arg) in zip(types, args) + var = gensym("$(func)_cconvert") + push!(cconvert_args, var) + push!(cconvert_exprs, quote + $var = Base.cconvert($(esc(typ)), $(esc(arg))) + end) + end + + unsafe_convert_exprs = [] + unsafe_convert_args = [] + for (typ, arg) in zip(types, cconvert_args) + var = gensym("$(func)_unsafe_convert") + push!(unsafe_convert_args, var) + push!(unsafe_convert_exprs, quote + $var = Base.unsafe_convert($(esc(typ)), $arg) + end) + end + + quote + $(cconvert_exprs...) + + $(unsafe_convert_exprs...) + + ret = ccall($(esc(func)), $(esc(rettype)), $(Expr(:tuple, map(esc, types)...)), + $(unsafe_convert_args...)) + end +end + +macro gcsafe_ccall(expr) + ccall_macro_lower(Base.ccall_macro_parse(expr)...) +end + +function cuMemPoolCreate2(pool, poolProps) + # CUDA.initialize_context() + #CUDA. + gc_state = @ccall(jl_gc_safe_enter()::Int8) + @gcsafe_ccall cuMemPoolCreate(pool::Ptr{CUmemoryPool2}, + poolProps::Ptr{CUmemPoolProps2})::Cvoid + @ccall(jl_gc_safe_leave(gc_state::Int8)::Cvoid) +end + +function cual() + props = Ref(CUmemPoolProps2( + ntuple(i->Char(0), 31) + )) + handle_ref = Ref{CUmemoryPool2}() + cuMemPoolCreate2(handle_ref, props) + + CuMemoryPool2(handle_ref[]) +end + +@testset "Unused shadow phi rev" begin + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(cual)}, Duplicated) +end + + const SEED = 42 const N_SAMPLES = 500 const N_COMPONENTS = 4 From 8da09fb6fb8a45b89608cf2d65819dc0389045cc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 23 May 2024 09:23:32 -0500 Subject: [PATCH 076/495] Faster type analysis (no JIT) (#1463) * Faster type analysis (no JIT) * fix --- src/absint.jl | 10 +++- src/compiler.jl | 104 +++++++++++++++++------------------------ src/rules/typerules.jl | 78 ------------------------------- 3 files changed, 53 insertions(+), 139 deletions(-) diff --git 
a/src/absint.jl b/src/absint.jl index 36c1689832..89ae5577b4 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -144,6 +144,7 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), ("jl_box_int32", Int32), ("ijl_box_int32", Int32), ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), + ("jl_box_float32", Float32), ("ijl_box_float32", Float32), ) if nm == fname return (true, ty) @@ -221,7 +222,11 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end if nm == "jl_array_copy" || nm == "ijl_array_copy" - return abs_typeof(operands(arg)[1], partial) + legal, RT = abs_typeof(operands(arg)[1], partial) + if legal + @assert RT <: Array + end + return (legal, RT) end _, RT = enzyme_custom_extract_mi(arg, false) @@ -284,6 +289,9 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ fieldoffset(typ, i+1) end - offset if fsize == llsz(value_type(larg)) + if Base.isconcretetype(subT) && is_concrete_tuple(subT) && length(subT.parameters) == 1 + subT = subT.parameters[1] + end return (true, subT) end end diff --git a/src/compiler.jl b/src/compiler.jl index bf8cd23114..511ec256d7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3197,70 +3197,15 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr retType = convert(API.CDIFFE_TYPE, rt) rules = Dict{String, API.CustomRuleType}( - "jl_apply_generic" => @cfunction(ptr_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_apply_generic" => @cfunction(ptr_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "julia.gc_alloc_obj" => @cfunction(alloc_obj_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_box_float32" => @cfunction(f32_box_rule, - UInt8, 
(Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_box_float32" => @cfunction(f32_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_box_int64" => @cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_box_int64" => @cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_box_uint64" => @cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_box_uint64" => @cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "jl_array_copy" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "ijl_array_copy" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_alloc_array_1d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_alloc_array_1d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_alloc_array_2d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_alloc_array_2d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_alloc_array_3d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - 
"ijl_alloc_array_3d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "julia.pointer_from_objref" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_wait" => @cfunction(noop_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_enq_work" => @cfunction(noop_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - - "enz_noop" => @cfunction(noop_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "jl_inactive_inout" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), @@ -4979,15 +4924,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # fn, but it doesn't presently so for now we will ensure this by hand if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly")]) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) continue end if EnzymeRules.is_inactive_from_sig(mi.specTypes; world, method_table, caller) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")]) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")]) continue end if EnzymeRules.is_inactive_noinl_from_sig(mi.specTypes; world, method_table, caller) - handleCustom(llvmfn, "enz_noop", 
[StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")], false, false) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")], false, false) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) @@ -5013,12 +4958,12 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end if func == typeof(Base.enq_work) && length(sparam_vals) == 1 && first(sparam_vals) <: Task - handleCustom(llvmfn, "jl_enq_work") + handleCustom(llvmfn, "jl_enq_work", [StringAttribute("enzyme_ta_norecur")]) continue end if func == typeof(Base.wait) || func == typeof(Base._wait) if length(sparam_vals) == 1 && first(sparam_vals) <: Task - handleCustom(llvmfn, "jl_wait") + handleCustom(llvmfn, "jl_wait", [StringAttribute("enzyme_ta_norecur")]) end continue end @@ -5191,11 +5136,50 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; GPUCompiler.optimize_module!(parent_job, mod) end + seen = TypeTreeTable() + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + dl = string(LLVM.datalayout(mod)) + ctx = LLVM.context(mod) for f in functions(mod), bb in blocks(f), inst in instructions(bb) if !isa(inst, LLVM.CallInst) continue end + fn = LLVM.called_operand(inst) + + if !API.HasFromStack(inst) && (!isa(fn, LLVM.Function) || isempty(blocks(fn))) + legal, source_typ = abs_typeof(inst) + codegen_typ = value_type(inst) + if legal + typ = if codegen_typ isa LLVM.PointerType + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + # pointers are used for multiple kinds of arguments + # - literal pointer values + if source_typ <: Ptr || source_typ <: Core.LLVMPtr + source_typ + elseif llvm_source_typ isa LLVM.PointerType + #if llvm_source_typ != codegen_typ + # 
throw(AssertionError("llvmtype ($llvm_source_typ) is not codegen_typ ($codegen_typ), source_typ = $source_typ within $(string(inst))")) + #end + # push!(args, (cc=MUT_REF, typ=source_typ, name=source_name, idx=codegen_i)) + Ptr{source_typ} + # - references to aggregates + else + @assert llvm_source_typ != codegen_typ + # push!(args, (cc=BITS_REF, typ=source_typ, name=source_name, idx=codegen_i)) + Ptr{source_typ} + end + else + codegen_typ + end + + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", string(typetree(typ, ctx, dl, seen)))) + elseif codegen_typ == T_prjlvalue + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", "{[-1]:Pointer}")) + end + end + if !isa(fn, LLVM.Function) continue end diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 4730db8654..569ef87323 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -1,28 +1,4 @@ -function noop_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - return UInt8(false) -end - -function alloc_obj_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - inst = LLVM.Instruction(val) - if API.HasFromStack(inst) - return UInt8(false) - end - legal, typ = abs_typeof(inst) - if !legal - return UInt8(false) - throw(AssertionError("Cannot deduce type of alloc obj, $(string(inst)) of $(string(LLVM.parent(LLVM.parent(inst))))")) - end - - ctx = LLVM.context(LLVM.Value(val)) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - - rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` - only!(rest, -1) - API.EnzymeMergeTypeTree(ret, rest) - return UInt8(false) -end - function int_return_rule(direction::Cint, ret::API.CTypeTreeRef, 
args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 TT = TypeTree(API.DT_Integer, LLVM.context(LLVM.Value(val))) only!(TT, -1) @@ -30,41 +6,6 @@ function int_return_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C return UInt8(false) end -function i64_box_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - val = LLVM.Instruction(val) - TT = TypeTree(API.DT_Pointer, LLVM.context(val)) - if (direction & API.DOWN) != 0 - sub = TypeTree(unsafe_load(args)) - ctx = LLVM.context(val) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(val))))) - maxSize = div(width(value_type(operands(val)[1]))+7, 8) - shift!(sub, dl, 0, maxSize, 0) - API.EnzymeMergeTypeTree(TT, sub) - end - only!(TT, -1) - API.EnzymeMergeTypeTree(ret, TT) - return UInt8(false) -end - - -function f32_box_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - TT = TypeTree(API.DT_Float, LLVM.context(LLVM.Value(val))) - only!(TT, -1) - API.EnzymeMergeTypeTree(unsafe_load(args), TT) - - API.EnzymeMergeTypeTree(TT, TypeTree(API.DT_Pointer,LLVM.context(LLVM.Value(val)))) - only!(TT, -1) - API.EnzymeMergeTypeTree(ret, TT) - return UInt8(false) -end - -function ptr_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - TT = TypeTree(API.DT_Pointer, LLVM.context(LLVM.Value(val))) - only!(TT, -1) - API.EnzymeSetTypeTree(ret, TT) - return UInt8(false) -end - function inout_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 if numArgs != 1 return UInt8(false) @@ -97,22 +38,3 @@ function 
inout_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT end return UInt8(false) end - -function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - inst = LLVM.Instruction(val) - - legal, typ = abs_typeof(inst) - @assert legal - - ctx = LLVM.context(LLVM.Value(val)) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - - rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` - only!(rest, -1) - API.EnzymeMergeTypeTree(ret, rest) - - for i = 1:numArgs - API.EnzymeMergeTypeTree(unsafe_load(args, i), TypeTree(API.DT_Integer, -1, ctx)) - end - return UInt8(false) -end From 1e8ff37b19c7f1ef1c44dcffc3190bdec89017b3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 23 May 2024 09:24:21 -0500 Subject: [PATCH 077/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index eb689480c5..621d5bbd5f 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.113" +Enzyme_jll = "0.0.114" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From 95a11a3dcd1b1d9700ad5cc207b609e7476a7870 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 24 May 2024 10:45:49 -0500 Subject: [PATCH 078/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 621d5bbd5f..003f360e05 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.8" +version = "0.12.9" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" 
CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7" -Enzyme_jll = "0.0.114" +Enzyme_jll = "0.0.115" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From cc2b4583f95ba2463af905bc91c2578337c9246e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 24 May 2024 12:26:19 -0500 Subject: [PATCH 079/495] Remove closure from active_reg_inner (#1466) --- src/compiler.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 511ec256d7..52a510fd33 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -371,6 +371,13 @@ end end) end +@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}) where {ST, Seen, justActive, UnionSret} + if ST isa Union + return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret)))) + end + return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret)) +end + @inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} if T === Any @@ -436,13 +443,7 @@ end # if sret union, the data is stored in a stack memory location and is therefore # not unique'd preventing the boxing of the union in the default case if UnionSret && is_sret_union(T) - @inline function recur(::Type{ST}) where ST - if ST isa Union - return forcefold(Val(recur(ST.a)), Val(recur(ST.b))) - end - return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret)) - end - return recur(T) + return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret)) else if justActive return AnyState From 0dd3c37b9cfe72a4071a49f5168d6a10e5c03007 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 25 May 2024 17:08:25 -0400 Subject: [PATCH 080/495] Nicer CUDA errors (#1470) --- 
src/compiler.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 52a510fd33..f80490d5f0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1353,8 +1353,13 @@ function emit_error(B::LLVM.IRBuilder, orig, string) string*=sprint(io->Base.show_backtrace(io, bt)) end + ct = if occursin("ptx", LLVM.triple(mod)) + GPUCompiler.emit_exception!(B, string, orig) + else + call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) + end + # 2. Call error function and insert unreachable - ct = call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("noreturn")) LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_error")) return ct From ba54eb26253d92b80928e14a97fa00c2ab1b39f1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 26 May 2024 08:06:38 +0200 Subject: [PATCH 081/495] Apply iterate in fwd mode (#1472) --- src/rules/jitrules.jl | 305 ++++++++++++++++++++++++++++++++++++++++-- test/runtests.jl | 225 +++++++++++++++++++++++++++++++ 2 files changed, 516 insertions(+), 14 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index b16fe9c481..76b72466c1 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -306,11 +306,270 @@ end return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) end +@inline concat() = () +@inline concat(a) = a +@inline concat(a, b) = (a..., b...) +@inline concat(a, b, c...) = concat(concat(a, b), c...) + +@inline iterate_unwrap_inner_fwd(x::Const) = (map(Const, x.val)...,) +@inline iterate_unwrap_inner_fwd(x::Duplicated) = (map(Duplicated, x.val, x.dval)...,) +@inline batch_dup_tuple(x, vals...) 
= BatchDuplicated(x, (vals...,)) +@inline iterate_unwrap_inner_fwd(x::BatchDuplicated) = (map(batch_dup_tuple, x.val, x.dval...)...,) + +@inline function iterate_unwrap_fwd(args...) + ntuple(Val(length(args))) do i + Base.@_inline_meta + iterate_unwrap_inner_fwd(args[i]) + end +end + +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function fwddiff_with_return(::Val{width}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {width, Nargs} + tt′ = Enzyme.vaTypeof(args...) + ReturnPrimal = Val(true) + RT = A + ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) + + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + world = codegen_world_age(Core.Typeof(f.val), tt) + + thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(f, args...) +end + +function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + nnothing = ntuple(i->nothing, Val(Width+1)) + nres = ntuple(i->:(res[1]), Val(Width+1)) + ModifiedBetween = ntuple(i->false, Val(N+1)) + ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) + Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + return quote + args0 = ($(wrapped...),) + args = concat(iterate_unwrap_fwd(args0...)...) + + dupClosure = ActivityTup[1] + FT = Core.Typeof(f) + if dupClosure && guaranteed_const(FT) + dupClosure = false + end + + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ForwardMode) + + annotation = @static if $Width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + else + Const{rt} + end + end + + res = fwddiff_with_return(Val($Width), dupClosure ? 
Duplicated(f, df) : Const(f), annotation, args...) + return if annotation <: Const + ReturnType(($(nres...),)) + else + if $Width == 1 + ReturnType((res[1], res[2])) + else + ReturnType((res[1], res[2]...)) + end + end + end +end + +function func_runtime_iterate_fwd(N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + + quote + function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} + $body + end + end +end + +@generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} + N = div(length(allargs)+2, Width+1)-1 + _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) +end + +function body_runtime_iterate_augfwd(N, Width, wrapped, primttypes) + nnothing = ntuple(i->nothing, Val(Width+1)) + nres = ntuple(i->:(origRet), Val(Width+1)) + nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) + nres3 = ntuple(i->:(res[3]), Val(Width)) + ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) + Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + + return quote + args = ($(wrapped...),) + throw(AssertionError("Runtime iterate augmented forward pass unhandled, f=$f df=$df args=$args")) + + # TODO: Annotation of return value + # tt0 = Tuple{$(primtypes...)} + tt′ = Tuple{$(Types...)} + rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) + annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + + dupClosure = ActivityTup[1] + FT = Core.Typeof(f) + if dupClosure && guaranteed_const(FT) + dupClosure = false + end + + world = codegen_world_age(FT, Tuple{$(ElTypes...)}) + + forward, adjoint = thunk(Val(world), (dupClosure ? 
Duplicated : Const){FT}, + annotation, tt′, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) + resT = typeof(origRet) + if annotation <: Const + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType(($(nres...), tape)) + elseif annotation <: Active + if $Width == 1 + shadow_return = Ref(make_zero(origRet)) + else + shadow_return = ($(nzeros...),) + end + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if $Width == 1 + return ReturnType((origRet, shadow_return, tape)) + else + return ReturnType((origRet, shadow_return..., tape)) + end + end + + @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed + + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if $Width == 1 + return ReturnType((origRet, initShadow, tape)) + else + return ReturnType((origRet, initShadow..., tape)) + end + end +end + +function func_runtime_iterate_augfwd(N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) + body = body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + + quote + function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + $body + end + end +end + +@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, MB, Width, ReturnType, F, DF} + N = div(length(allargs)+2, Width+1)-1 + _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) +end + +function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) + outs = [] + for i in 1:N + for w in 1:Width + expr = if Width == 1 + :(tup[$i]) + else + :(tup[$i][$w]) + end + shad = shadowargs[i][w] + out = :(if tup[$i] === nothing + elseif $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr) + else + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + end + ) + push!(outs, out) + end + end + shadow_ret = nothing + if Width == 1 + shadowret = :(tape.shadow_return[]) + else + shadowret = [] + for w in 1:Width + push!(shadowret, :(tape.shadow_return[$w][])) + end + shadowret = :(($(shadowret...),)) + end + + ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) + Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + + quote + args = ($(wrapped...),) + throw(AssertionError("Runtime iterate reverse pass unhandled, f=$f df=$df args=$args")) + + # TODO: Annotation of return value + # tt0 = Tuple{$(primtypes...)} + tt = Tuple{$(ElTypes...)} + tt′ = Tuple{$(Types...)} + rt = Core.Compiler.return_type(f, tt) + annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + + dupClosure = ActivityTup[1] + FT = Core.Typeof(f) + if dupClosure && guaranteed_const(FT) + dupClosure = false + end + world = codegen_world_age(FT, tt) + + forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing + args = (args..., $shadowret) + end + + tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + + $(outs...) 
+ return nothing + end +end + +function func_runtime_iterate_rev(N, Width) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) + body = body_runtime_iterate_rev(N, Width, wrapped, primtypes, batchshadowargs) + + quote + function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + $body + end + end +end + +@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} + N = div(length(allargs)+2, Width+1)-1 + _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) +end + # Create specializations for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_generic_fwd(N, Width)) eval(func_runtime_generic_augfwd(N, Width)) eval(func_runtime_generic_rev(N, Width)) + eval(func_runtime_iterate_fwd(N, Width)) + eval(func_runtime_iterate_augfwd(N, Width)) + eval(func_runtime_iterate_rev(N, Width)) end function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false) @@ -776,24 +1035,42 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) unsafe_store!(shadowR, shadowres.ref) return false end - emit_error(B, orig, "Enzyme: Not yet implemented, forward for jl_f__apply_iterate") - if unsafe_load(shadowR) != C_NULL - cal = new_from_original(gutils, orig) - width = get_width(gutils) - if width == 1 - shadow = cal - else - ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - shadow = LLVM.UndefValue(ST) - for i in 1:width - shadow = insert_value!(B, shadow, cal, i-1) - if i == 1 - API.moveBefore(cal, 
shadow, B) + + if v && isiter == Base.iterate + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + sret = generic_setup(orig, runtime_iterate_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+2, B, false) + AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + if unsafe_load(shadowR) != C_NULL + if width == 1 + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + shadow = LLVM.load!(B, T_prjlvalue, gep) + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) end end + unsafe_store!(shadowR, shadow.ref) end - unsafe_store!(shadowR, shadow.ref) + + if unsafe_load(normalR) != C_NULL + normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) + end + return false end + + emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate "*string((v, v2, isiter, istup, length(operands(orig)), offset+4))) + return false end diff --git a/test/runtests.jl b/test/runtests.jl index f0f2aa9fa0..4988151fae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1647,6 +1647,231 @@ end end + +concat() = () +concat(a) = a +concat(a, b) = (a..., b...) +concat(a, b, c...) = concat(concat(a, b), c...) + +metaconcat(x) = concat(x...) + +metaconcat2(x, y) = concat(x..., y...) + +midconcat(x, y) = (x, concat(y...)...) + +metaconcat3(x, y, z) = concat(x..., y..., z...) 
+ +@testset "Forward Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(13.7, 15.2), (100.02, 304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + + a = [("a", "b"), ("c", "d")] + da = [("e", "f"), ("g", "h")] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) + @test length(res) == 4 + @test res[1] == "a" + @test res[2] == "b" + @test res[3] == "c" + @test res[4] == "d" + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + + Enzyme.autodiff(Forward, metaconcat, Const(a)) + + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Const(a)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == 
"a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + y = [(-92.0, -93.0), (-97.9, -911.2)] + dy = [(-913.7, -915.2), (-9100.02, -9304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) + @test length(res) == 8 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test res[5] ≈ -92.0 + @test res[6] ≈ -93.0 + @test res[7] ≈ -97.9 + @test res[8] ≈ -911.2 + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(res) == 12 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 
7.9 + @test res[4] ≈ 11.2 + + @test res[5] == "a" + @test res[6] == "b" + @test res[7] == "c" + @test res[8] == "d" + + @test res[9] ≈ -92.0 + @test res[10] ≈ -93.0 + @test res[11] ≈ -97.9 + @test res[12] ≈ -911.2 + + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 +end + @testset "Dynamic Val Construction" begin dyn_f(::Val{D}) where D = prod(D) From 8c88c076232023194b58a0d1770c0b9bcdd3bdb7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 27 May 2024 00:38:54 +0200 Subject: [PATCH 082/495] Consider noalias info from julia custom rules (#1467) * Consider noalias info from julia custom rules * Update Project.toml * add noalias call attr * mightalias * fix ret attr * remove null ptr error message * Better methoderror for fwd * fixup * fixup * fixup * bump ecore version --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/rules.jl | 15 
+++++++++ src/absint.jl | 11 +++---- src/compiler.jl | 28 +++++++++++++++- src/rules/customrules.jl | 65 +++++++++++++++++++++++-------------- test/kwrrules.jl | 2 +- test/kwrules.jl | 2 +- test/rules.jl | 4 +-- 9 files changed, 93 insertions(+), 38 deletions(-) diff --git a/Project.toml b/Project.toml index 003f360e05..ca45bbb85e 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ EnzymeStaticArraysExt = "StaticArrays" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7" +EnzymeCore = "0.7.3" Enzyme_jll = "0.0.115" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 5249f78945..670e1f3014 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.2" +version = "0.7.3" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 727ee1b178..398c790087 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -219,6 +219,21 @@ function is_inactive_noinl_from_sig(@nospecialize(TT); return isapplicable(inactive_noinl, TT; world, method_table, caller) end +""" + noalias(func::typeof(f), args...) + +Mark a particular function as always being a fresh allocation which does not alias any other +accessible memory. 
+""" +function noalias end + +function noalias_from_sig(@nospecialize(TT); + world::UInt=Base.get_world_counter(), + method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, + caller::Union{Nothing,Core.MethodInstance}=nothing) + return isapplicable(noalias, TT; world, method_table, caller) +end + """ inactive_type(::Type{Ty}) diff --git a/src/absint.jl b/src/absint.jl index 89ae5577b4..ae9c35a09b 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -113,12 +113,11 @@ function absint(arg::LLVM.Value, partial::Bool=false) end ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) if ptr == C_NULL - # XXX: Is this correct? - bt = GPUCompiler.backtrace(arg) - btstr = sprint() do io - Base.show_backtrace(io, bt) - end - @error "Found null pointer at\n $btstr" arg + # bt = GPUCompiler.backtrace(arg) + # btstr = sprint() do io + # Base.show_backtrace(io, bt) + # end + # @error "Found null pointer at\n $btstr" arg return (false, nothing) end typ = Base.unsafe_pointer_to_objref(ptr) diff --git a/src/compiler.jl b/src/compiler.jl index f80490d5f0..c2fae5d90d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4863,7 +4863,23 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if llvmfn == primalf actualRetType = k.ci.rettype end + + if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller) + push!(return_attributes(llvmfn), EnumAttribute("noalias")) + for u in LLVM.uses(llvmfn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if cf == llvmfn + LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) + end + end + end + func = mi.specTypes.parameters[1] + meth = mi.def name = meth.name jlmod = meth.module @@ -4891,7 +4907,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end - func = mi.specTypes.parameters[1] sparam_vals = mi.specTypes.parameters[2:end] # 
mi.sparam_vals if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) @@ -4912,6 +4927,17 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; ]) continue end + if func == typeof(Base.mightalias) + handleCustom(llvmfn, "jl_mightalias", + [EnumAttribute("readonly", 0), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation"), + EnumAttribute("nofree"), + StringAttribute("enzyme_ta_norecur"), + ], true, false) + continue + end if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) name = (func == typeof(Base.Threads.threadid)) ? "jl_threadid" : "jl_nthreads" handleCustom(llvmfn, name, diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 7ec09e2c1d..c65da1585a 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -243,6 +243,10 @@ function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) return RT, needsPrimal, needsShadowP[] != 0, origNeedsPrimal end +function custom_rule_method_error(world, fn, args...) 
+ throw(MethodError(fn, (args...,), world)) +end + function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true @@ -305,20 +309,24 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) @safe_debug "Applying custom forward rule (kwcall)" TT llvmf = nested_codegen!(mode, mod, kwfunc, TT, world) fwd_RT = Core.Compiler.return_type(kwfunc, TT, world) + else + TT = Tuple{typeof(world), typeof(kwfunc), TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + fwd_RT = Union{} end else if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world) @safe_debug "Applying custom forward rule" TT llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world) fwd_RT = Core.Compiler.return_type(EnzymeRules.forward, TT, world) + else + TT = Tuple{typeof(world), typeof(EnzymeRules.forward), TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + fwd_RT = Union{} end end - - if llvmf === nothing - @safe_debug "No custom forward rule is applicable for" TT - emit_error(B, orig, "Enzyme: No custom rule was applicable for " * string(TT)) - return false - end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) @@ -340,7 +348,6 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) sret = nothing end - if length(args) != length(parameters(llvmf)) GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, orig, isKWCall, kwtup, TT, sret, returnRoots return false @@ -524,6 +531,11 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, ami = GPUCompiler.methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) @safe_debug "Applying custom augmented_primal rule (kwcall)" TT=augprimal_TT catch e + augprimal_TT = Tuple{typeof(world), typeof(kwfunc), augprimal_TT.parameters...} + 
ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), augprimal_TT, world) + if forward + pushfirst!(args, LLVM.ConstantInt(world)) + end end else @assert kwtup === nothing @@ -535,20 +547,19 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, ami = GPUCompiler.methodinstance(Core.Typeof(EnzymeRules.augmented_primal), augprimal_TT, world) @safe_debug "Applying custom augmented_primal rule" TT=augprimal_TT catch e + augprimal_TT = Tuple{typeof(world), typeof(EnzymeRules.augmented_primal), augprimal_TT.parameters...} + ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), augprimal_TT, world) + if forward + pushfirst!(args, LLVM.ConstantInt(world)) + end end end - - if ami !== nothing - target = DefaultCompilerTarget() - params = PrimalCompilerParams(mode) - job = CompilerJob(ami, CompilerConfig(target, params; kernel=false), world) - interp = GPUCompiler.get_interpreter(job) - aug_RT = something(Core.Compiler.typeinf_type(interp, ami.def, ami.specTypes, ami.sparam_vals), Any) - else - @safe_debug "No custom augmented_primal rule is applicable for" augprimal_TT - emit_error(B, orig, "Enzyme: No custom augmented_primal rule was applicable for " * string(augprimal_TT)) - return C_NULL - end + + target = DefaultCompilerTarget() + params = PrimalCompilerParams(mode) + aug_RT = something(Core.Compiler.typeinf_type(GPUCompiler.get_interpreter(CompilerJob(ami, CompilerConfig(target, params; kernel=false), world)), ami.def, ami.specTypes, ami.sparam_vals), Any) + + @assert ami !== nothing if kwtup !== nothing && kwtup <: Duplicated @safe_debug "Non-constant keyword argument found for " augprimal_TT @@ -596,20 +607,24 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, @safe_debug "Applying custom reverse rule (kwcall)" TT=rev_TT llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) + else + rev_TT = Tuple{typeof(world), 
typeof(rkwfunc), rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} end else if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) @safe_debug "Applying custom reverse rule" TT=rev_TT llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) + else + rev_TT = Tuple{typeof(world), typeof(EnzymeRules.reverse), rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} end end - - if llvmf == nothing - @safe_debug "No custom reverse rule is applicable for" rev_TT - emit_error(B, orig, "Enzyme: No custom reverse rule was applicable for " * string(rev_TT)) - return C_NULL - end end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) diff --git a/test/kwrrules.jl b/test/kwrrules.jl index 72708d993b..a62ba94608 100644 --- a/test/kwrrules.jl +++ b/test/kwrrules.jl @@ -61,7 +61,7 @@ end # Test that this errors due to missing kwargs in rule definition g2(x, y) = f_kw2(x; val=y) -@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, g2, Active(2.0), Const(42.0))[1][1] +@test_throws MethodError autodiff(Reverse, g2, Active(2.0), Const(42.0))[1][1] function f_kw3(x; val=nothing) diff --git a/test/kwrules.jl b/test/kwrules.jl index 13f916c65d..91d3dc859d 100644 --- a/test/kwrules.jl +++ b/test/kwrules.jl @@ -31,7 +31,7 @@ end # Test that this errors due to missing kwargs in rule definition g2(x, y) = f_kw2(x; val=y) -@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Forward, g2, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 14.0 +@test_throws MethodError autodiff(Forward, g2, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 14.0 function f_kw3(x; val=nothing) x^2 diff --git a/test/rules.jl b/test/rules.jl index 4c2db62bf1..b6644d8c55 100644 --- 
a/test/rules.jl +++ b/test/rules.jl @@ -87,11 +87,11 @@ function forward(func::Const{typeof(g)}, ::Type{<:Const}, x::Const) end @testset "Registry" begin - @test_throws Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Forward, g, Duplicated(1.0, 1.0)) + @test_throws MethodError Enzyme.autodiff(Forward, g, Duplicated(1.0, 1.0)) rh(cond, x) = cond ? g(x) : x @test Enzyme.autodiff(Forward, rh, Const(false), Duplicated(1.0, 1.0)) == (1.0,) - @test_throws Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Forward, rh, Const(true), Duplicated(1.0, 1.0)) + @test_throws MethodError Enzyme.autodiff(Forward, rh, Const(true), Duplicated(1.0, 1.0)) end function alloc_sq(x) From cd569d6d55168bfa21bbe8c742b909664439f1e1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 27 May 2024 01:17:20 +0200 Subject: [PATCH 083/495] rewrite calls (#1453) * rewrite calls * Fix * remove nulls * fix --- src/compiler.jl | 5 ++- src/compiler/validation.jl | 73 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index c2fae5d90d..5a23126498 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2247,7 +2247,10 @@ end function get_julia_inner_types(B, p, startvals...; added=[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - vals = LLVM.Value[p] + vals = LLVM.Value[] + if p != nothing + push!(vals, p) + end todo = LLVM.Value[startvals...] 
while length(todo) != 0 cur = popfirst!(todo) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 03e7a4458e..f8fa3a4cd2 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -161,6 +161,78 @@ function check_ir(job, mod::LLVM.Module) end end +# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA +function rewrite_ccalls!(mod::LLVM.Module) + for f in collect(functions(mod)) + replaceAndErase = Tuple{Instruction, Instruction}[] + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.CallInst) + changed = false + newbundles = OperandBundleDef[] + B = IRBuilder() + position!(B, inst) + for bunduse in operand_bundles(inst) + bunduse = LLVM.OperandBundleDef(bunduse) + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + uservals = LLVM.Value[] + subchanged = false + for lval in LLVM.inputs(bunduse) + llty = value_type(lval) + if !isa(llty, LLVM.PointerType) || LLVM.addrspace(llty) != 10 + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + subchanged = true + end + if !subchanged + push!(newbundles, bunduse) + continue + end + changed = true + push!(newbundles, OperandBundleDef(LLVM.tag_name(bunduse), uservals)) + end + changed = false + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!(B, called_type(inst), called_operand(inst), collect(arguments(inst)), newbundles, prevname) + for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] 
+ idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); + Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j in 1:count + LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + end + end + for (inst, newinst) in replaceAndErase + replace_uses!(inst, newinst) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end +end + function check_ir!(job, errors, mod::LLVM.Module) imported = Set(String[]) if haskey(functions(mod), "malloc") @@ -174,6 +246,7 @@ function check_ir!(job, errors, mod::LLVM.Module) replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) unsafe_delete!(mod, f) end + rewrite_ccalls!(mod) for f in collect(functions(mod)) check_ir!(job, errors, imported, f) end From 74d88b4e77c9334dad1bd70b34490e1cf23d9dfa Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 27 May 2024 13:39:44 +0200 Subject: [PATCH 084/495] Make nofree errors nicer (#1474) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5a23126498..66c083b71b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1483,7 +1483,7 @@ end function Base.showerror(io::IO, ece::NoDerivativeException) print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing + if ece.ir !== nothing && !occursin("No create nofree of empty function", ece.msg) print(io, "Current scope: \n") print(io, ece.ir) end From 78a88d6085557572050f94640471c5b0f29e1afd Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 27 May 2024 13:39:56 +0200 Subject: [PATCH 085/495] Mark growat as nofree (#1473) --- src/compiler.jl | 1 + 1 file changed, 1 insertion(+) diff --git 
a/src/compiler.jl b/src/compiler.jl index 66c083b71b..8e897c2526 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_array_grow_at", "jl_array_grow_at", "ijl_try_substrtod", "jl_try_substrtod", "jl_f__apply_iterate", "ijl_field_index", "jl_field_index", From e362c361fccfe128b9d3b3ebc4c2dca563dc163d Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 27 May 2024 13:40:10 +0200 Subject: [PATCH 086/495] Fix 1.6 (#1475) --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 4988151fae..182d8291df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1710,6 +1710,7 @@ metaconcat3(x, y, z) = concat(x..., y..., z...) Enzyme.autodiff(Forward, metaconcat, Const(a)) +@static if VERSION ≥ v"1.7-" dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) @test length(dres) == 5 @test dres[1] ≈ 7.0 @@ -1755,6 +1756,7 @@ metaconcat3(x, y, z) = concat(x..., y..., z...) 
@test dres[3] == "b" @test dres[4] == "c" @test dres[5] == "d" +end y = [(-92.0, -93.0), (-97.9, -911.2)] dy = [(-913.7, -915.2), (-9100.02, -9304.1)] From 5609c7edf976717f2678dc5d300d1a8bccf8bb64 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 28 May 2024 11:48:53 +0200 Subject: [PATCH 087/495] Nice union{} error (#1479) * Nice union{} error * fixup --- src/Enzyme.jl | 4 +-- src/compiler.jl | 79 +++++++++++++++++++++++++++++++++--------------- test/runtests.jl | 9 ++++++ 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c75508cd77..911d1801ad 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -230,7 +230,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end if A <: Active - if !allocatedinline(rt) || rt isa Union + if (!allocatedinline(rt) || rt isa Union) && rt != Union{} forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) res = forward(f, args...) 
tape = res[1] @@ -244,7 +244,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Duplicated Returns not yet handled")) end - if A <: Active && rt <: Complex + if (A <: Active && rt <: Complex) && rt != Union{} if Holomorphic seen = IdDict() seen2 = IdDict() diff --git a/src/compiler.jl b/src/compiler.jl index 8e897c2526..8cbac14f94 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -577,6 +577,10 @@ struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT adjoint::PT end +struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World} <: AbstractThunk{FA, RT, TT, Width} + adjoint::PT +end + @inline return_type(::AbstractThunk{FA, RT}) where {FA, RT} = RT @inline return_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = RT @@ -5277,7 +5281,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; cf = LLVM.called_operand(tmp) if isa(cf, LLVM.Function) nm = LLVM.name(cf) - if nm == "gpu_signal_exception" || nm == "gpu_report_exception" + if nm == "gpu_signal_exception" || nm == "gpu_report_exception" || nm == "ijl_throw" || nm == "jl_throw" shouldemit = false break end @@ -5433,6 +5437,9 @@ struct CompileResult{AT, PT} TapeType::Type end +@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} = +enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) + @inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) 
@@ -5536,7 +5543,9 @@ end end @inline function default_adjoint(T) - if T <: AbstractFloat + if T == Union{} + return nothing + elseif T <: AbstractFloat return one(T) elseif T <: Complex error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff") @@ -5559,7 +5568,7 @@ end JuliaContext() do ctx F = eltype(FA) - is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk + is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk needs_tape = CC <: AdjointThunk @@ -5569,23 +5578,33 @@ end argtypes = DataType[argtt.parameters...] argexprs = Union{Expr, Symbol}[:(args[$i]) for i in 1:N] - if !RawCall + if false && CC <: PrimalErrorThunk + primargs = [quote + convert($(eltype(T)), $(argexprs[i]).val) + end for (i, T) in enumerate(argtypes)] + return quote + fn.val($(primargs...)) + error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. 
Giving up") + end + end + + if !RawCall && !(CC <: PrimalErrorThunk) if rettype <: Active if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end elseif rettype <: Const if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end else if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end end @@ -5593,8 +5612,10 @@ end types = DataType[] - if eltype(rettype) === Union{} - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + if eltype(rettype) === Union{} && false + return quote + error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + end end if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) rrt = eltype(rettype) @@ -5665,7 +5686,9 @@ end end continue end - + if CC <: PrimalErrorThunk + continue + end if T <: Active if is_adjoint if width == 1 @@ -5752,8 +5775,10 @@ end end push!(sret_types, NT) end - - @assert i == length(argexprs)+1 + + if !(CC <: PrimalErrorThunk) + @assert i == length(argexprs)+1 + end # Tape if CC <: AugmentedForwardThunk @@ -5785,7 +5810,7 @@ end T_void = convert(LLVMType, Nothing) - combinedReturn = Tuple{sret_types...} + combinedReturn = (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? 
Union{} : Tuple{sret_types...} if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) combinedReturn = AnonymousStruct(combinedReturn) end @@ -6003,29 +6028,30 @@ end params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} - interp = GPUCompiler.get_interpreter(tmp_job) # TODO check compile return here, early # rrt = Core.Compiler.return_type(f, primal.tt) # nothing rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + + run_enzyme = true if rrt == Union{} - estr = "Function to differentiate `$mi` is guaranteed to return an error and doesn't make sense to autodiff. Giving up" - return quote - error($estr) - end + run_enzyme = false + A = Const end - if !(A <: Const) && guaranteed_const_nongen(rrt, World) + if run_enzyme && !(A <: Const) && guaranteed_const_nongen(rrt, World) estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" return quote error($estr) end end - rt2 = if A isa UnionAll + rt2 = if !run_enzyme + Const{rrt} + elseif A isa UnionAll A{rrt} else @assert A isa DataType @@ -6034,7 +6060,7 @@ end A end - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) # We need to use primal as the key, to lookup the right method @@ -6045,7 +6071,13 @@ 
end compile_result = cached_compilation(job) - if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient + if !run_enzyme + ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World} + return quote + Base.@_inline_meta + $ErrT($(compile_result.adjoint)) + end + elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} @@ -6086,7 +6118,6 @@ import GPUCompiler: deferred_codegen_jobs params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} interp = GPUCompiler.get_interpreter(tmp_job) rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) diff --git a/test/runtests.jl b/test/runtests.jl index 182d8291df..f7132e1d75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2602,6 +2602,15 @@ end @test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), Const(true))[1][1] end + +function assured_err(x) + throw(AssertionError("foo")) +end + +@testset "UnionAll" begin + @test_throws AssertionError Enzyme.autodiff(Reverse, assured_err, Active, Active(2.0)) +end + struct MyFlux end From edc9b9d75137eab7799b6000186b95514382d621 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 27 May 2024 09:26:06 -0400 Subject: [PATCH 088/495] Add broadcast noalias tests --- Project.toml | 2 +- test/runtests.jl | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ca45bbb85e..7bcf3a8438 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.3" -Enzyme_jll = "0.0.115" +Enzyme_jll = "0.0.117" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/test/runtests.jl b/test/runtests.jl index f7132e1d75..d37adfc1d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2289,6 +2289,68 @@ end end end +function bc0_test_function(ps) + z = view(ps, 26:30) + C = Matrix{Float64}(undef, 5, 1) + C .= z + return C[1] +end + +@noinline function bc1_bcs2(x, y) + x != y && error(2) + return x +end + +@noinline function bc1_affine_normalize(x::AbstractArray) + # _axes = broadcast_shape(axes(x), axes(x)) #Broadcast.combine_axes(x, x) + _axes = bc1_bcs2(axes(x), axes(x)) + i = Broadcast.Broadcasted(Base.Broadcast.DefaultArrayStyle{2}(), +, (x,), _axes) + + dest = similar(Array{Float32}, _axes) + bc = convert(Broadcast.Broadcasted{Nothing}, i) + + # mycopyto!(dest, bc) + copyto!(dest, bc) + return x +end + +function bc1_loss_function(x) + return bc1_affine_normalize(x)[1] +end + +function bc2_affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, + scale::AbstractArray, bias::AbstractArray, epsilon::Real) + _scale = @. scale / sqrt(xvar + epsilon) + _bias = @. bias - xmean * _scale + return @. 
x * _scale + _bias +end + +function bc2_loss_function(x, scale, bias) + x_ = reshape(x, 6, 6, 3, 2, 2) + scale_ = reshape(scale, 1, 1, 3, 2, 1) + bias_ = reshape(bias, 1, 1, 3, 2, 1) + + xmean = mean(x_, dims=(1, 2, 5)) + xvar = var(x_, corrected=false, mean=xmean, dims=(1, 2, 5)) + + return sum(abs2, bc2_affine_normalize(identity, x_, xmean, xvar, scale_, bias_, 1e-5)) +end + +@testset "Broadcast noalias" begin + + x = ones(30) + autodiff(Reverse, bc0_test_function, Active, Const(x)) + + x = rand(Float32, 2, 3) + Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) + + x = rand(Float32, 6, 6, 6, 2) + sc = rand(Float32, 6) + bi = rand(Float32, 6) + Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), + Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) +end + @testset "GetField" begin mutable struct MyType x::Float64 From 63a6759c8cc0dc9f08e7f70b120893f90268b52b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 28 May 2024 07:04:03 -0400 Subject: [PATCH 089/495] fix test --- test/runtests.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d37adfc1d6..3db032fefd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2302,14 +2302,9 @@ end end @noinline function bc1_affine_normalize(x::AbstractArray) - # _axes = broadcast_shape(axes(x), axes(x)) #Broadcast.combine_axes(x, x) _axes = bc1_bcs2(axes(x), axes(x)) - i = Broadcast.Broadcasted(Base.Broadcast.DefaultArrayStyle{2}(), +, (x,), _axes) - dest = similar(Array{Float32}, _axes) - bc = convert(Broadcast.Broadcasted{Nothing}, i) - - # mycopyto!(dest, bc) + bc = convert(Broadcast.Broadcasted{Nothing}, Broadcast.instantiate(Base.broadcasted(+, x, x))) copyto!(dest, bc) return x end From cf1851b101eba251fb52cbe50a3cd843ab76c401 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Tue, 28 May 2024 08:03:27 -0400 Subject: [PATCH 090/495] Only run broadcast noalias on 1.8+ --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 3db032fefd..94809a7f3e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2331,6 +2331,7 @@ function bc2_loss_function(x, scale, bias) return sum(abs2, bc2_affine_normalize(identity, x_, xmean, xvar, scale_, bias_, 1e-5)) end +@static if VERSION ≥ v"1.8-" @testset "Broadcast noalias" begin x = ones(30) @@ -2345,6 +2346,7 @@ end Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) end +end @testset "GetField" begin mutable struct MyType From 1e45f264dbd2dacd79686c891f2d8c42ead33fce Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 28 May 2024 14:57:23 +0200 Subject: [PATCH 091/495] Allow custom rule for constant arg/ret in rev mode (#1371) * Allow custom rule for constant arg/ret in rev mode * cse * Add differential use handler * fixup * fix * fix * fixup * fixup * fixup --- src/api.jl | 4 +- src/gradientutils.jl | 4 +- src/rules/customrules.jl | 211 +++++++++++++++++++++++++-------------- src/rules/llvmrules.jl | 11 ++ 4 files changed, 154 insertions(+), 76 deletions(-) diff --git a/src/api.jl b/src/api.jl index e7d37f8cb5..c0632d2600 100644 --- a/src/api.jl +++ b/src/api.jl @@ -232,6 +232,8 @@ EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHan EnzymeInsertValue(B::LLVM.IRBuilder, v::LLVM.Value, v2::LLVM.Value, insts::Vector{Cuint}, name="") = LLVM.Value(ccall((:EnzymeInsertValue, libEnzyme), LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVMValueRef, LLVMValueRef, Ptr{Cuint}, Int64, Cstring), B, v, v2, insts, length(insts), name)) +const CustomDiffUse = Ptr{Cvoid} +EnzymeRegisterDiffUseCallHandler(name, handle) = ccall((:EnzymeRegisterDiffUseCallHandler, libEnzyme), Cvoid, (Cstring, CustomDiffUse), name, 
handle) EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall((:EnzymeSetCalledFunction, libEnzyme), Cvoid, (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), ci, fn, toremove, length(toremove)) EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall((:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), LLVMValueRef, (LLVMValueRef,UInt8,Ptr{Int64}, Int64), fn, keepret, args, length(args)) EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T) @@ -260,7 +262,7 @@ EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyze EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val) -EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) +EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), UInt8, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) diff --git a/src/gradientutils.jl b/src/gradientutils.jl index cc64726f8e..f7f80fd396 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -24,7 +24,9 @@ function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) end function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) - API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) + if 
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1 + uncacheable .= 1 + end return uncacheable end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index c65da1585a..e8c573a176 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1,5 +1,5 @@ -function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) +function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) ops = collect(operands(orig)) called = ops[end] ops = ops[1:end-1] @@ -46,6 +46,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) if !(isKWCall && arg.arg_i == 1) push!(overwritten, false) end + if B !== nothing if Core.Compiler.isconstType(arg.typ) && !Core.Compiler.isconstType(Const{arg.typ}) llty = convert(LLVMType, Const{arg.typ}) al0 = al = emit_allocobj!(B, Const{arg.typ}) @@ -63,6 +64,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) else @assert isghostty(Const{arg.typ}) || Core.Compiler.isconstType(Const{arg.typ}) end + end continue end @assert !(isghostty(arg.typ) || Core.Compiler.isconstType(arg.typ)) @@ -74,7 +76,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) end val = new_from_original(gutils, op) - if reverse + if reverse && B !== nothing val = lookup_value(gutils, val, B) end @@ -100,21 +102,23 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) Ty = Const{arg.typ} llty = convert(LLVMType, Ty) arty = convert(LLVMType, arg.typ; allow_boxed=true) - al0 = al = emit_allocobj!(B, Ty) - al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) - al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) + if B !== nothing + al0 = al = emit_allocobj!(B, Ty) + al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + al = addrspacecast!(B, al, LLVM.PointerType(llty, 
Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) - if value_type(val) != eltype(value_type(ptr)) - val = load!(B, arty, val) - end - store!(B, val, ptr) + ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + if value_type(val) != eltype(value_type(ptr)) + val = load!(B, arty, val) + end + store!(B, val, ptr) - if any_jltypes(llty) - emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) - end + if any_jltypes(llty) + emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + end - push!(args, al) + push!(args, al) + end push!(activity, Ty) @@ -122,29 +126,33 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) Ty = Active{arg.typ} llty = convert(LLVMType, Ty) arty = convert(LLVMType, arg.typ; allow_boxed=true) - al0 = al = emit_allocobj!(B, Ty) - al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) - al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) + if B !== nothing + al0 = al = emit_allocobj!(B, Ty) + al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) - if value_type(val) != eltype(value_type(ptr)) - @assert !overwritten[end] - val = load!(B, arty, val) - end - store!(B, val, ptr) + ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + if value_type(val) != eltype(value_type(ptr)) + @assert !overwritten[end] + val = load!(B, arty, val) + end + store!(B, val, ptr) - if any_jltypes(llty) - emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) - end + if any_jltypes(llty) + emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + end - push!(args, al) + push!(args, al) + end push!(activity, Ty) push!(actives, 
op) else - ival = invert_pointer(gutils, op, B) - if reverse - ival = lookup_value(gutils, ival, B) + if B !== nothing + ival = invert_pointer(gutils, op, B) + if reverse + ival = lookup_value(gutils, ival, B) + end end if width == 1 if activep == API.DFT_DUP_ARG @@ -165,31 +173,33 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) llty = convert(LLVMType, Ty) arty = convert(LLVMType, arg.typ; allow_boxed=true) sarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, arty)) - al0 = al = emit_allocobj!(B, Ty) - al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) - al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) + if B !== nothing + al0 = al = emit_allocobj!(B, Ty) + al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) - if value_type(val) != eltype(value_type(ptr)) - val = load!(B, arty, val) - ptr_val = ival - ival = UndefValue(sarty) - for idx in 1:width - ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) - ld = load!(B, arty, ev) - ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1) + ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + if value_type(val) != eltype(value_type(ptr)) + val = load!(B, arty, val) + ptr_val = ival + ival = UndefValue(sarty) + for idx in 1:width + ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + ld = load!(B, arty, ev) + ival = (width == 1 ) ? 
ld : insert_value!(B, ival, ld, idx-1) + end end - end - store!(B, val, ptr) + store!(B, val, ptr) - iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) - store!(B, ival, iptr) + iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) + store!(B, ival, iptr) - if any_jltypes(llty) - emit_writebarrier!(B, get_julia_inner_types(B, al0, val, ival)) - end + if any_jltypes(llty) + emit_writebarrier!(B, get_julia_inner_types(B, al0, val, ival)) + end - push!(args, al) + push!(args, al) + end push!(activity, Ty) end @@ -197,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) return args, activity, (overwritten...,), actives, kwtup end -function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) +function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) width = get_width(gutils) mode = get_mode(gutils) @@ -206,7 +216,23 @@ function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + # Conditionally use the get return. This is done because EnzymeGradientUtilsGetReturnDiffeType + # calls differential use analysis to determine needsprimal/shadow. However, since now this function + # is used as part of differential use analysis, we need to avoid an infinite recursion. Thus use + # the version without differential use if actual unreachable results are not available anyways.
+ uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) + activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 + API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + else + actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) + if !isghostty(RealRt) + needsPrimalP[] = 1 + if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED + needsShadowP[] = 1 + end + end + actv + end needsPrimal = needsPrimalP[] != 0 origNeedsPrimal = needsPrimal _, sret, _ = get_return_info(RealRt) @@ -349,7 +375,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) end if length(args) != length(parameters(llvmf)) - GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, orig, isKWCall, kwtup, TT, sret, returnRoots + GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, string(value_type(llvmf)), orig, isKWCall, kwtup, TT, sret, returnRoots return false end @@ -476,19 +502,9 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) return false end -function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef - - ctx = LLVM.context(orig) - +@inline function aug_fwd_mi(orig::LLVM.CallInst, gutils::GradientUtils, forward=false, B=nothing) width = get_width(gutils) - shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - if shadowR != C_NULL - unsafe_store!(shadowR,UndefValue(shadowType).ref) - end - - # TODO: don't inject the code multiple times for multiple calls - # 1) extract out the MI from attributes mi, RealRt = enzyme_custom_extract_mi(orig) isKWCall = isKWCallSignature(mi.specTypes) @@ -503,11 +519,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, needsShadow end - alloctx = LLVM.IRBuilder() - position!(alloctx, 
LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) - - curent_bb = position(B) - fn = LLVM.parent(curent_bb) + fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} @@ -554,13 +566,55 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, end end end + return ami, augprimal_TT, (args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal) +end + +@inline function has_aug_fwd_rule(orig, gutils) + return aug_fwd_mi(orig, gutils)[1] !== nothing +end + +function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef + + ctx = LLVM.context(orig) + + width = get_width(gutils) + + shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + if shadowR != C_NULL + unsafe_store!(shadowR,UndefValue(shadowType).ref) + end + + # TODO: don't inject the code multiple times for multiple calls + + # 1) extract out the MI from attributes + mi, RealRt = enzyme_custom_extract_mi(orig) + isKWCall = isKWCallSignature(mi.specTypes) + + # 2) Create activity, and annotate function spec + ami, augprimal_TT, setup = aug_fwd_mi(orig, gutils, forward, B) + args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal = setup + + needsShadowJL = if RT <: Active + false + else + needsShadow + end + + C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} + + alloctx = LLVM.IRBuilder() + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + world = enzyme_extract_world(fn) + + mode = get_mode(gutils) + + @assert ami !== nothing target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) aug_RT = something(Core.Compiler.typeinf_type(GPUCompiler.get_interpreter(CompilerJob(ami, 
CompilerConfig(target, params; kernel=false), world)), ami.def, ami.specTypes, ami.sparam_vals), Any) - - @assert ami !== nothing - if kwtup !== nothing && kwtup <: Duplicated @safe_debug "Non-constant keyword argument found for " augprimal_TT emit_error(B, orig, "Enzyme: Non-constant keyword argument found for " * string(augprimal_TT)) @@ -904,7 +958,7 @@ end function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return true end tape = enzyme_custom_common_rev(#=forward=#true, B, orig, gutils, normalR, shadowR, #=tape=#nothing) @@ -916,9 +970,18 @@ end function enzyme_custom_rev(B, orig, gutils, tape) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return end enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) return nothing end + +function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) + # use default + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) + return (false, true) + end + # don't use default and always require the arg + return (true, false) +end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index fa2efceed8..45d25b1b83 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1172,7 +1172,18 @@ macro fwdfunc(f) )) end + +macro diffusefunc(f) + :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin + res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} + unsafe_store!(useDefault, UInt8(res[2])) + UInt8(res[1]) + end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, 
API.CDerivativeMode, Ptr{UInt8}) + )) +end + @noinline function register_llvm_rules() + API.EnzymeRegisterDiffUseCallHandler("enzyme_custom", @diffusefunc(enzyme_custom_diffuse)) register_handler!( ("julia.call",), @augfunc(jlcall_augfwd), From b5addb62c0c9d0d74825be5fc71deebfd26d793c Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 29 May 2024 13:30:53 +0200 Subject: [PATCH 092/495] Add finalizer under jlcall (#1483) --- src/rules/llvmrules.jl | 18 ++++++++++++++---- src/rules/typeunstablerules.jl | 30 +++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 45d25b1b83..37208f4afc 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -31,6 +31,9 @@ function jlcall_fwd(B, orig, gutils, normalR, shadowR) if in(name, ("ijl_f__svec_ref", "jl_f__svec_ref")) return common_f_svec_ref_fwd(2, B, orig, gutils, normalR, shadowR) end + if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) + return common_finalizer_fwd(2, B, orig, gutils, normalR, shadowR) + end if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) return true end @@ -69,6 +72,9 @@ function jlcall_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if in(name, ("ijl_f__svec_rev", "jl_f__svec_ref")) return common_f_svec_ref_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end + if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) + return common_finalizer_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) + end if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) return true end @@ -115,6 +121,10 @@ function jlcall_rev(B, orig, gutils, tape) common_f_svec_ref_rev(2, B, orig, gutils, tape) return nothing end + if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) + common_finalizer_rev(2, B, orig, gutils, tape) + return nothing + end if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), 
collect(function_attributes(F)))) return nothing end @@ -1103,7 +1113,7 @@ function finalizer_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_gc_add_finalizer_th or jl_gc_add_ptr_finalizer") + err = emit_error(B, orig, "Enzyme: unhandled forward for jl_gc_add_finalizer_th or jl_gc_add_ptr_finalizer") newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing @@ -1117,9 +1127,9 @@ function finalizer_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - # err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_gc_add_finalizer_th") - # newo = new_from_original(gutils, orig) - # API.moveBefore(newo, err, B) + err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_gc_add_finalizer_th") + newo = new_from_original(gutils, orig) + API.moveBefore(newo, err, B) normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index f70e15c82c..c6639eb8aa 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -822,7 +822,7 @@ function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_f__svec_ref") + emit_error(B, orig, "Enzyme: unhandled forward for jl_f__svec_ref") normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) @@ -900,6 +900,34 @@ function common_f_svec_ref_rev(offset, B, orig, gutils, tape) return nothing end +function common_finalizer_fwd(offset, B, orig, gutils, normalR, shadowR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + emit_error(B, orig, "Enzyme: unhandled forward for jl_f_finalizer") + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + return false +end + +function common_finalizer_augfwd(offset, B, orig, gutils, normalR, shadowR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_f_finalizer") + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + return false +end + +function common_finalizer_rev(offset, B, orig, gutils, tape) + return nothing +end + function f_svec_ref_fwd(B, orig, gutils, normalR, shadowR) common_f_svec_ref_fwd(1, B, orig, gutils, normalR, shadowR) return nothing From 7526a5c77964c52f03c8a94e0971fd0c9468aaa0 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 1 Jun 2024 14:46:12 -0400 Subject: [PATCH 093/495] Setup benchmarks (#1489) --- .github/workflows/benchmark_pr.yml | 76 ++++++++++++++++++++++++++++++ benchmark/benchmarks.jl | 16 +++++++ 2 files changed, 92 insertions(+) create mode 100644 .github/workflows/benchmark_pr.yml create mode 100644 benchmark/benchmarks.jl diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml new file mode 100644 index 0000000000..1af037fd6a --- /dev/null +++ b/.github/workflows/benchmark_pr.yml @@ -0,0 +1,76 @@ +name: Benchmark a pull 
request + +on: + pull_request: + +permissions: + pull-requests: write + +jobs: + generate_plots: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 + with: + version: "1" + - uses: julia-actions/cache@v1 + - name: Extract Package Name from Project.toml + id: extract-package-name + run: | + PACKAGE_NAME=$(grep "^name" Project.toml | sed 's/^name = "\(.*\)"$/\1/') + echo "::set-output name=package_name::$PACKAGE_NAME" + - name: Build AirspeedVelocity + env: + JULIA_NUM_THREADS: 2 + run: | + # Lightweight build step, as sometimes the runner runs out of memory: + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.add("AirspeedVelocity")' + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.build("AirspeedVelocity")' + - name: Add ~/.julia/bin to PATH + run: | + echo "$HOME/.julia/bin" >> $GITHUB_PATH + - name: Run benchmarks + run: | + echo $PATH + ls -l ~/.julia/bin + mkdir results + benchpkg ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.repository.default_branch}}" --output-dir=results/ --tune + - name: Create plots from benchmarks + run: | + mkdir -p plots + benchpkgplot ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --npart=10 --format=png --input-dir=results/ --output-dir=plots/ + - name: Upload plot as artifact + uses: actions/upload-artifact@v2 + with: + name: plots + path: plots + - name: Create markdown table from benchmarks + run: | + benchpkgtable ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --input-dir=results/ --ratio > table.md + echo '### Benchmark Results' > body.md + echo '' >> body.md + echo '' >> body.md 
+ cat table.md >> body.md + echo '' >> body.md + echo '' >> body.md + echo '### Benchmark Plots' >> body.md + echo 'A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.' >> body.md + echo 'Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).' >> body.md + + - name: Find Comment + uses: peter-evans/find-comment@v2 + id: fcbenchmark + with: + issue-number: ${{ github.event.pull_request.number }} + comment-author: 'github-actions[bot]' + body-includes: Benchmark Results + + - name: Comment on PR + uses: peter-evans/create-or-update-comment@v3 + with: + comment-id: ${{ steps.fcbenchmark.outputs.comment-id }} + issue-number: ${{ github.event.pull_request.number }} + body-path: body.md + edit-mode: replace diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl new file mode 100644 index 0000000000..5b5b7c03f8 --- /dev/null +++ b/benchmark/benchmarks.jl @@ -0,0 +1,16 @@ +# To run: +# using PkgBenchmark, Enzyme +# result = benchmarkpkg(Enzyme) +# export_markdown("benchmark/perf.md", result) + +# Note: if you change this file you will need to delete and regenerate tune.json +# Your "v1.x" environment needs to have BenchmarkTools and PkgBenchmark installed.
+ +using BenchmarkTools +using Enzyme + +const SUITE = BenchmarkGroup() + +SUITE["basics"] = BenchmarkGroup() + +SUITE["basics"]["overhead"] = @benchmarkable Enzyme.autodiff(Forward, identity, Const(1.0)) \ No newline at end of file From 21b0762d2939f4b41c7449fa281a7f5261a7139f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 3 Jun 2024 01:11:41 +0200 Subject: [PATCH 094/495] llvm.julia.gc_preserve_begin splatting (#1486) * llvm.julia.gc_preserve_begin splatting * fix --- src/compiler/validation.jl | 46 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index f8fa3a4cd2..68eb4a5bca 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -167,10 +167,52 @@ function rewrite_ccalls!(mod::LLVM.Module) replaceAndErase = Tuple{Instruction, Instruction}[] for bb in blocks(f), inst in instructions(bb) if isa(inst, LLVM.CallInst) + fn = called_operand(inst) changed = false - newbundles = OperandBundleDef[] B = IRBuilder() position!(B, inst) + if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" + uservals = LLVM.Value[] + for lval in collect(arguments(inst)) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + changed = true + end + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), prevname) + for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] 
+ idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); + Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j in 1:count + LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + continue + end + newbundles = OperandBundleDef[] for bunduse in operand_bundles(inst) bunduse = LLVM.OperandBundleDef(bunduse) if LLVM.tag_name(bunduse) != "jl_roots" @@ -181,7 +223,7 @@ function rewrite_ccalls!(mod::LLVM.Module) subchanged = false for lval in LLVM.inputs(bunduse) llty = value_type(lval) - if !isa(llty, LLVM.PointerType) || LLVM.addrspace(llty) != 10 + if isa(llty, LLVM.PointerType) push!(uservals, lval) continue end From e17da2c04e56cb31661f7f31bb907c7732ef5b42 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 5 Jun 2024 13:26:38 +0200 Subject: [PATCH 095/495] Test forward mode blas (#1490) * Test forward mdoe blas * Update Project.toml * fix * Update Project.toml --- Project.toml | 4 ++-- deps/build_local.jl | 57 +++++++++++++++++++++++++++++++++++++-------- src/compiler.jl | 2 +- test/runtests.jl | 4 ++-- 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 7bcf3a8438..8052e3dad3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.9" +version = "0.12.10" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.3" -Enzyme_jll = "0.0.117" +Enzyme_jll = "0.0.119" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 
7" ObjectFile = "0.4" diff --git a/deps/build_local.jl b/deps/build_local.jl index 5c67ac0477..5f833ce1a5 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -6,7 +6,8 @@ Enzyme_jll = Base.UUID("7cc45869-7501-5eee-bdea-0790c847d4ef") using Pkg, Scratch, Preferences, Libdl -BUILD_TYPE = "RelWithDebInfo" +BUILD_TYPE = "RelWithDebInfo" +BCLoad = true # 1. Get a scratch directory scratch_dir = get_scratch!(Enzyme_jll, "build") @@ -14,12 +15,35 @@ isdir(scratch_dir) && rm(scratch_dir; recursive=true) source_dir = nothing branch = nothing -if length(ARGS) == 2 - @assert ARGS[1] == "--branch" - branch = ARGS[2] - source_dir = nothing -elseif length(ARGS) == 1 - source_dir = ARGS[1] + +args = (ARGS...,) +while length(args) > 0 + global args + global branch + global source_dir + if length(args) >= 2 && args[1] == "--branch" + branch = args[2] + args = (args[3:end]...,) + continue + end + if length(args) >= 1 && args[1] == "--debug" + BUILD_TYPE = "Debug" + args = (args[2:end]...,) + continue + end + if length(args) >= 1 && args[1] == "--nobcload" + BCLoad = false + args = (args[2:end]...,) + continue + end + if source_dir == nothing + source_dir = args[1] + args = (args[2:end]...,) + continue + end + @show args + @assert length(args) == 0 + break end if branch === nothing @@ -62,7 +86,12 @@ LLVM_VER_MAJOR = Base.libllvm_version.major # Build! 
@info "Building" source_dir scratch_dir LLVM_DIR run(`cmake -DLLVM_DIR=$(LLVM_DIR) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) -DENZYME_EXTERNAL_SHARED_LIB=ON -B$(scratch_dir) -S$(source_dir)`) -run(`cmake --build $(scratch_dir) --parallel $(Sys.CPU_THREADS) -t Enzyme-$(LLVM_VER_MAJOR) EnzymeBCLoad-$(LLVM_VER_MAJOR)`) + +if BCLoad + run(`cmake --build $(scratch_dir) --parallel $(Sys.CPU_THREADS) -t Enzyme-$(LLVM_VER_MAJOR) EnzymeBCLoad-$(LLVM_VER_MAJOR)`) +else + run(`cmake --build $(scratch_dir) --parallel $(Sys.CPU_THREADS) -t Enzyme-$(LLVM_VER_MAJOR)`) +end # Discover built libraries built_libs = filter(readdir(joinpath(scratch_dir, "Enzyme"))) do file @@ -72,18 +101,26 @@ end lib_path = joinpath(scratch_dir, "Enzyme", only(built_libs)) isfile(lib_path) || error("Could not find library $lib_path in build directory") +# Tell Enzyme_jll to load our library instead of the default artifact one +set_preferences!( + joinpath(dirname(@__DIR__), "LocalPreferences.toml"), + "Enzyme_jll", + "libEnzyme_path" => lib_path, + force=true, +) + +if BCLoad built_libs = filter(readdir(joinpath(scratch_dir, "BCLoad"))) do file endswith(file, ".$(Libdl.dlext)") && startswith(file, "lib") end libBC_path = joinpath(scratch_dir, "BCLoad", only(built_libs)) isfile(libBC_path) || error("Could not find library $libBC_path in build directory") - # Tell Enzyme_jll to load our library instead of the default artifact one set_preferences!( joinpath(dirname(@__DIR__), "LocalPreferences.toml"), "Enzyme_jll", - "libEnzyme_path" => lib_path, "libEnzymeBCLoad_path" => libBC_path; force=true, ) +end diff --git a/src/compiler.jl b/src/compiler.jl index 8cbac14f94..30bf6f0d9c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4632,7 +4632,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; disableFallback = String[] # Tablegen BLAS does not support forward mode yet - if mode != API.DEM_ForwardMode + if !(mode == API.DEM_ForwardMode && Enzyme.API.runtimeActivity()) for ty in 
("s", "d") for func in ("dot","gemm","gemv","axpy","copy","scal") for prefix in ("", "cblas_") diff --git a/test/runtests.jl b/test/runtests.jl index 94809a7f3e..225ddf435f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,7 +69,7 @@ function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1) dx = zero(x) autodiff(Reverse, f, Active, Duplicated(x, dx)) - @test isapprox(reshape(dx, length(dx)), dx_fd; rtol=rtol, atol=atol, kwargs...) + @test isapproxfn((Enzyme.Reverse, f), reshape(dx, length(dx)), dx_fd; rtol=rtol, atol=atol, kwargs...) dx_fwd = map(eachindex(x)) do i dx = zero(x) @@ -77,7 +77,7 @@ function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1) ∂x = autodiff(Forward, f, Duplicated(x, dx)) isempty(∂x) ? zero(eltype(dx)) : ∂x[1] end - @test isapprox(dx_fwd, dx_fd; rtol=rtol, atol=atol, kwargs...) + @test isapproxfn((Enzyme.Forward, f), dx_fwd, dx_fd; rtol=rtol, atol=atol, kwargs...) end Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false) From f8427ca02357cdde67003294f9107bdc27553b40 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 6 Jun 2024 19:13:16 +0200 Subject: [PATCH 096/495] Workaround breakage in llvm.jl (#1497) Co-authored-by: Valentin Churavy --- src/api.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/api.jl b/src/api.jl index c0632d2600..3c626635b0 100644 --- a/src/api.jl +++ b/src/api.jl @@ -104,6 +104,7 @@ struct CFnTypeInfo end +@static if isdefined(LLVM, :InstructionMetadataDict) Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL @@ -115,6 +116,7 @@ function Base.getindex(md::LLVM.InstructionMetadataDict, kind::String) Base.setindex!(md::LLVM.InstructionMetadataDict, node::LLVM.Metadata, kind::String) = ccall((:EnzymeSetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), md.inst, kind, 
LLVM.Value(node)) +end @cenum(CDIFFE_TYPE, DFT_OUT_DIFF = 0, # add differential to an output struct From b2b5161213cbaa6e2c69fd690be0d2f97aa17216 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:33:18 -0400 Subject: [PATCH 097/495] CompatHelper: add new compat entry for StaticArrays in [weakdeps] at version 1, (keep existing compat) (#1469) Co-authored-by: CompatHelper Julia --- Project.toml | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 8052e3dad3..0b19c3cead 100644 --- a/Project.toml +++ b/Project.toml @@ -16,16 +16,6 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[extensions] -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" @@ -36,9 +26,20 @@ LLVM = "6.1, 7" ObjectFile = "0.4" Preferences = "1.4" SpecialFunctions = "1, 2" +StaticArrays = "1" julia = "1.6" +[extensions] +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" From ebe4cf7dcc1500a6a17cc84cc1c3000574410066 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 6 Jun 2024 23:05:32 +0330 
Subject: [PATCH 098/495] Update `CompatHelper` workflow (#1478) Co-authored-by: Valentin Churavy --- .github/workflows/CompatHelper.yml | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 9ed073b516..bda2faf4ed 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -8,19 +8,27 @@ jobs: CompatHelper: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Get Julia compatibility id: julia_compat # NOTE: this requires a julia compat lower-bound with minor version! run : | version=$(grep '^julia = ' Project.toml | grep -o '".*"' | cut -d '"' -f2) echo "::set-output name=version::$version" - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v2 with: version: ${{ steps.julia_compat.outputs.version }} - - name: Pkg.add("CompatHelper") - run: julia -e 'using Pkg; Pkg.add("CompatHelper")' - - name: CompatHelper.main() + arch: x64 + show-versioninfo: true + - name: Pkg.add + shell: julia --color=yes {0} + run: | + import Pkg + Pkg.add("CompatHelper") + - name: CompatHelper.main env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: julia -e 'using CompatHelper; CompatHelper.main()' + shell: julia --color=yes {0} + run: | + import CompatHelper + CompatHelper.main(; include_jll = true, subdirs = ["", "docs", "test", "deps", "lib/EnzymeCore", "lib/EnzymeCore/test", "lib/EnzymeTestUtils", "lib/EnzymeTestUtils/test"]) From 8d0be2ed98dbe90487afee53791912d25c052edc Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 6 Jun 2024 15:45:56 -0400 Subject: [PATCH 099/495] Update CompatHelper.yml --- .github/workflows/CompatHelper.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index bda2faf4ed..591cbae5c4 100644 --- a/.github/workflows/CompatHelper.yml +++ 
b/.github/workflows/CompatHelper.yml @@ -3,6 +3,7 @@ name: CompatHelper on: schedule: - cron: '0 0 * * *' + workflow_dispatch: jobs: CompatHelper: From f53c53c32a3f1e54177347728ff68b1f791b2747 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 6 Jun 2024 15:53:07 -0400 Subject: [PATCH 100/495] CompatHelper: add new compat entry for Preferences at version 1 for package deps, (keep existing compat) (#1514) Co-authored-by: CompatHelper Julia --- deps/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deps/Project.toml b/deps/Project.toml index b9566c7d69..b68337ab88 100644 --- a/deps/Project.toml +++ b/deps/Project.toml @@ -3,3 +3,6 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Scratch = "6c6a2e73-6563-6170-7368-637461726353" + +[compat] +Preferences = "1" From f74c13768681355e778ee73566807c0239349937 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 6 Jun 2024 15:53:45 -0400 Subject: [PATCH 101/495] CompatHelper: add new compat entry for Scratch at version 1 for package deps, (keep existing compat) (#1515) Co-authored-by: CompatHelper Julia Co-authored-by: Valentin Churavy --- deps/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/deps/Project.toml b/deps/Project.toml index b68337ab88..94ba9b06c0 100644 --- a/deps/Project.toml +++ b/deps/Project.toml @@ -5,4 +5,5 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Scratch = "6c6a2e73-6563-6170-7368-637461726353" [compat] +Scratch = "1" Preferences = "1" From 8b00cf2b3c88f7fdadd962693bd5d455f4c9d503 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 6 Jun 2024 21:11:00 -0400 Subject: [PATCH 102/495] turn off CompatHelper for test --- .github/workflows/CompatHelper.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 591cbae5c4..aaeb8a60d4 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -32,4 +32,4 @@ jobs: shell: julia --color=yes {0} run: | import CompatHelper - CompatHelper.main(; include_jll = true, subdirs = ["", "docs", "test", "deps", "lib/EnzymeCore", "lib/EnzymeCore/test", "lib/EnzymeTestUtils", "lib/EnzymeTestUtils/test"]) + CompatHelper.main(; include_jll = true, subdirs = ["", "docs", "deps", "lib/EnzymeCore", "lib/EnzymeTestUtils"]) From 3f7b80e433a8ab59a09543b9e8259bc5465a0d8f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 6 Jun 2024 21:15:43 -0400 Subject: [PATCH 103/495] CompatHelper: bump compat for Documenter to 1 for package docs, (keep existing compat) (#1499) Co-authored-by: CompatHelper Julia --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index 1fa467423b..9ebcb622b3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,4 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Documenter = "0.27.8" +Documenter = "0.27.8, 1" From 23dd259dc8f7d49a25329bbc1fe4edd220f69acc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 6 Jun 2024 21:16:24 -0400 Subject: [PATCH 104/495] CompatHelper: add new compat entry for Literate at version 2 for package docs, (keep existing compat) (#1500) Co-authored-by: CompatHelper Julia Co-authored-by: Valentin Churavy --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index 9ebcb622b3..73ef4b4f01 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] 
+Literate = "2" Documenter = "0.27.8, 1" From 90fdd2e6e9695044874692414ab304812113987d Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 7 Jun 2024 14:39:18 -0400 Subject: [PATCH 105/495] Avoid undefref error in validation.jl (#1523) --- src/compiler/validation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 68eb4a5bca..80db3cfb39 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -426,7 +426,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) flib = Base.unsafe_pointer_to_objref(ld) end end - if isa(flib, GlobalRef) + if isa(flib, GlobalRef) && isdefined(flib.mod, flib.name) flib = getfield(flib.mod, flib.name) end From 5a2031121da8a20e994525c2ee91865eae61cdb7 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 7 Jun 2024 14:39:36 -0400 Subject: [PATCH 106/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0b19c3cead..8e31225cd9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.10" +version = "0.12.11" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From ad7694eb5fefd17472267a706e6720cd95ede561 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 7 Jun 2024 14:43:53 -0400 Subject: [PATCH 107/495] Fix docs build (#1524) --- docs/Project.toml | 2 +- docs/make.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 73ef4b4f01..56dd852972 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,4 +5,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Literate = "2" -Documenter = "0.27.8, 1" +Documenter = "1" diff --git a/docs/make.jl b/docs/make.jl index 022e0b5f8d..4f2eea837d 100644 --- a/docs/make.jl +++ 
b/docs/make.jl @@ -59,7 +59,6 @@ makedocs(; ] ], doctest = true, - strict = true, ) deploydocs(; From df9bff9628f868dc581904795d3f2eeeeb430e42 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 8 Jun 2024 12:28:41 -0400 Subject: [PATCH 108/495] Make zero in place (#1518) * Make zero in place * add make_zero! * more fixes and tests --- Project.toml | 2 +- examples/custom_rule.jl | 8 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 7 ++ src/Enzyme.jl | 6 +- src/api.jl | 2 +- src/compiler.jl | 202 ++++++++++++++++++++++++++++++- test/runtests.jl | 22 ++++ 8 files changed, 239 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 8e31225cd9..87c6d55dcc 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.3" +EnzymeCore = "0.7.4" Enzyme_jll = "0.0.119" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 836d299c1e..c2098006c2 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -134,7 +134,7 @@ function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNee if !(x isa Const) && !(y isa Const) y.dval .= 2 .* x.val .* x.dval elseif !(y isa Const) - y.dval .= 0 + make_zero!(y.dval) end dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val)) if RT <: Const @@ -211,7 +211,7 @@ function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, t x.dval .+= 2 .* xval .* dret.val ## also accumulate any derivative in y's shadow into x's shadow. x.dval .+= 2 .* xval .* y.dval - y.dval .= 0 + make_zero!(y.dval) return (nothing, nothing) end @@ -251,8 +251,8 @@ end x = [3.0, 1.0] y = [0.0, 0.0] -dx .= 0 -dy .= 0 +make_zero!(dx) +make_zero!(dy) autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx)) @show dx # derivative of h w.r.t. 
x diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 670e1f3014..20a89b9a05 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.3" +version = "0.7.4" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 30577a38e8..fb788fd5a6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -228,6 +228,13 @@ function autodiff_deferred_thunk end """ function make_zero end +""" + make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing + + Recursively set a variable's differentiable fields to zero. Only applicable for mutable types `T`. +""" +function make_zero! end + """ make_zero(prev::T) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 911d1801ad..7626304944 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -14,8 +14,8 @@ export BatchDuplicatedFunc import EnzymeCore: batch_size, get_func export batch_size, get_func -import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero -export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero +import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! +export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! export jacobian, gradient, gradient!
export markType, batch_size, onehot, chunkedonehot @@ -1007,7 +1007,7 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) ``` """ @inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} - dx .= 0 + make_zero!(dx) autodiff(Reverse, f, Active, Duplicated(x, dx)) dx end diff --git a/src/api.jl b/src/api.jl index 3c626635b0..d68d904d5a 100644 --- a/src/api.jl +++ b/src/api.jl @@ -104,7 +104,7 @@ struct CFnTypeInfo end -@static if isdefined(LLVM, :InstructionMetadataDict) +@static if !isdefined(LLVM, :ValueMetadataDict) Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL diff --git a/src/compiler.jl b/src/compiler.jl index 30bf6f0d9c..fac6907b59 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1298,7 +1298,7 @@ end xi = getfield(prev, i) T = Core.Typeof(xi) xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) + setfield!(y, i, xi) end end return y @@ -1324,6 +1324,204 @@ end return y end +function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S} + zero(T) +end + +function make_zero_immutable!(prev::Complex{T}, seen::S)::Complex{T} where {T <: AbstractFloat, S} + zero(Complex{T}) +end + +function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S} + ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable!(prev[i], seen) + end +end + +function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S} + NamedTuple{a, b}( + ntuple(Val(length(a))) do i + Base.@_inline_meta + make_zero_immutable!(prev[a[i]], seen) + end + ) +end + + +function make_zero_immutable!(prev::T, seen::S)::T where {T, S} + if guaranteed_const_nongen(T, nothing) + return prev + end + @assert !ismutable(prev) # test the value: ismutable(T) would inspect the DataType object itself + + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + + 
flds = Vector{Any}(undef, nf) + for i in 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + ST = Core.Typeof(xi) + flds[i] = if active_reg_inner(ST, (), nothing, #=justActive=#Val(true)) == ActiveState + make_zero_immutable!(xi, seen) + else + EnzymeCore.make_zero!(xi, seen) + xi + end + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T <: AbstractFloat, ST} + prev[] = zero(T) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}}, seen::ST)::Nothing where {T <: AbstractFloat, ST} + prev[] = zero(Complex{T}) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} + fill!(prev, zero(T)) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} + fill!(prev, zero(Complex{T})) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T})::Nothing where {T <: AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}})::Nothing where {T <: AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N})::Nothing where {T <: AbstractFloat, N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N})::Nothing where {T <: AbstractFloat, N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T, N, ST} + if guaranteed_const_nongen(T, nothing) + return + end + if in(prev, seen) # already visited: stop so cyclic references terminate + return + end + push!(seen, prev) + + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + 
SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + @inbounds prev[I] = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + end + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T, ST} + if guaranteed_const_nongen(T, nothing) + return + end + if in(prev, seen) + return + end + push!(seen, prev) + + pv = prev[] + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + prev[] = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} + pv = prev.contents + T = Core.Typeof(pv) + if guaranteed_const_nongen(T, nothing) + return + end + if in(prev, seen) + return + end + push!(seen, prev) + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + prev.contents = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::T, seen::S=Base.IdSet{Any}())::Nothing where {T, S} + if guaranteed_const_nongen(T, nothing) + return + end + if in(prev, seen) + return + end + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + + + if nf == 0 + return + end + + push!(seen, prev) + + for i in 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + SBT = Core.Typeof(xi) + if guaranteed_const_nongen(SBT, nothing) + continue + end + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + setfield!(prev, i, make_zero_immutable!(xi, seen)) + nothing + else + EnzymeCore.make_zero!(xi, seen) + nothing + end + end + end + return +end + struct EnzymeRuntimeException <: Base.Exception msg::Cstring end @@ -5536,7 +5734,7 @@ end 
@assert ismutable(x) yi = getfield(y, i) nexti = recursive_add(xi, yi, f, mutable_register) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, nexti) + setfield!(x, i, nexti) end end end diff --git a/test/runtests.jl b/test/runtests.jl index 225ddf435f..0212ec0d83 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -181,6 +181,28 @@ end # @test thunk_split.primal !== C_NULL # @test thunk_split.primal !== thunk_split.adjoint # @test thunk_a.adjoint !== thunk_split.adjoint + # + z = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9]) + Enzyme.make_zero!(z) + @test z[1] ≈ [0.0, 0.0, 0.0] + @test z[2][1] == 0 + @test z[2][2] == 1 + @test z[3] ≈ [0.0, 0.0] + + z2 = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9]) + Enzyme.make_zero!(z2) + @test z2[1] ≈ [0.0, 0.0, 0.0] + @test z2[2][1] == 0 + @test z2[2][2] == 1 + @test z2[3] ≈ [0.0, 0.0] + + z3 = [3.4, "foo"] + Enzyme.make_zero!(z3) + @test z3[1] ≈ 0.0 + @test z3[2] == "foo" + + z4 = sin + Enzyme.make_zero!(z4) end @testset "Reflection" begin From 86da3cdae09a5c6d0f877c4d9e01a4491d414501 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 8 Jun 2024 17:39:06 -0400 Subject: [PATCH 109/495] Reverse mode apply iterate (#1485) * Reverse mode apply iterate * fixed * fixup * cleanup * debugging fixes * fixup * cleanup * fix tests * fix batch getfield rev * fix tests * more test fix * fix tuple fast path * fix * Update Project.toml * fix sym index rev * fix test * fixup * Fix unionall * fix * fix sym offset * ix constantarray * Update Project.toml --- Project.toml | 4 +- src/Enzyme.jl | 7 + src/compiler.jl | 11 +- src/compiler/validation.jl | 16 +- src/rules/jitrules.jl | 662 +++++++++++++++++++++++---------- src/rules/typeunstablerules.jl | 152 +++++++- src/utils.jl | 2 +- test/applyiter.jl | 491 ++++++++++++++++++++++++ test/runtests.jl | 288 ++------------ 9 files changed, 1145 insertions(+), 488 deletions(-) create mode 100644 test/applyiter.jl diff --git a/Project.toml b/Project.toml index 87c6d55dcc..848c47e7ee 100644 --- 
a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.11" +version = "0.12.12" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.4" -Enzyme_jll = "0.0.119" +Enzyme_jll = "0.0.121" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7626304944..a6bc604e6a 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -74,6 +74,13 @@ end end)...} end +@inline function vaEltypes(args::Type{Ty}) where {Ty <: Tuple} + return Tuple{(ntuple(Val(length(Ty.parameters))) do i + Base.@_inline_meta + eltype(Ty.parameters[i]) + end)...} +end + @inline function same_or_one_helper(current, next) if current == -1 return next diff --git a/src/compiler.jl b/src/compiler.jl index fac6907b59..cca67bc874 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -380,7 +380,6 @@ end end @inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} - if T === Any return DupState end @@ -422,7 +421,9 @@ end else inmi = GPUCompiler.methodinstance(typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world) args = Any[EnzymeCore.EnzymeRules.inactive_type, T]; - ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + GC.@preserve T begin + ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + end end if inactivety @@ -480,11 +481,13 @@ end @static if VERSION < v"1.7.0" nT = T else - nT = if is_concrete_tuple(T) + nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) Tuple{(ntuple(length(T.parameters)) do i 
Base.@_inline_meta sT = T.parameters[i] - if sT isa Core.TypeofVararg + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg Any else sT diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 80db3cfb39..caf86cbc03 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -743,7 +743,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end - seen = Dict{LLVM.Value,Tuple}() + seen = Set{Tuple{LLVM.Value,Tuple}}() while length(todo) != 0 cur, off = pop!(todo) @@ -751,11 +751,10 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width cur = operands(cur)[1] end - if cur in keys(seen) - @assert seen[cur] == off + if cur in seen continue end - seen[cur] = off + push!(seen, (cur, off)) if isa(cur, LLVM.PHIInst) for (v, _) in LLVM.incoming(cur) @@ -781,7 +780,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width # if inserting at the current desired offset, we have found the value we need if ind == off[1] - push!(todo, (operands(cur)[2], -1)) + push!(todo, (operands(cur)[2], off[2:end])) # otherwise it must be inserted at a different point else push!(todo, (operands(cur)[1], off)) @@ -880,10 +879,15 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end + if isa(cur, LLVM.ConstantArray) + push!(todo, (cur[off[1]], off[2:end])) + continue + end + msg = sprint() do io::IO println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") println(io, string(enzymefn)) - println(io, "cur=", cur) + println(io, "cur=", string(cur)) println(io, "off=", off) end throw(AssertionError(msg)) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 76b72466c1..af12d2bfbc 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,5 +1,5 @@ -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) +function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, 
base=nothing, iterate=false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -59,8 +59,36 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) @assert length(primargs) == N @assert length(primtypes) == N wrapped = Expr[] + modbetween = Expr[:(MB[1])] for i in 1:N - expr = :( + if iterate + push!(modbetween, quote + ntuple(Val(length($(primargs[i])))) do _ + Base.@_inline_meta + MB[$i] + end + end) + end + expr = if iterate + :( + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + if !$forwardMode && active_reg($(primtypes[i])) + iterate_unwrap_augfwd_act($(primargs[i])...) + else + $((Width == 1) ? quote + iterate_unwrap_augfwd_dup(Val($forwardMode), $(primargs[i]), $(shadowargs[i])) + end : quote + iterate_unwrap_augfwd_batchdup(Val($forwardMode), Val($Width), $(primargs[i]), $(shadowargs[i])) + end + ) + end + else + map(Const, $(primargs[i])) + end + ) + else + :( if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) @assert $(primtypes[i]) !== DataType if !$forwardMode && active_reg($(primtypes[i])) @@ -73,9 +101,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) end ) + end push!(wrapped, expr) end - return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs + return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -110,7 +139,6 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end world = codegen_world_age(FT, tt) - forward = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) 
@@ -131,7 +159,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function func_runtime_generic_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote @@ -143,14 +171,14 @@ end @generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) nnothing = ntuple(i->nothing, Val(Width+1)) nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) + nzeros = ntuple(i->:(Ref(make_zero(origRet))), Val(Width)) nres3 = ntuple(i->:(res[3]), Val(Width)) ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) @@ -162,7 +190,13 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) # tt0 = Tuple{$(primtypes...)} tt′ = Tuple{$(Types...)} rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end dupClosure = ActivityTup[1] FT = Core.Typeof(f) @@ -209,19 +243,19 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) end function func_runtime_generic_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) + 
_, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(false, N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes) quote - function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _= setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes) end @@ -267,7 +301,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) tt = Tuple{$(ElTypes...)} tt′ = Tuple{$(Types...)} rt = Core.Compiler.return_type(f, tt) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end dupClosure = ActivityTup[1] FT = Core.Typeof(f) @@ -278,6 +318,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) forward, adjoint = thunk(Val(world), (dupClosure ? 
Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing args = (args..., $shadowret) end @@ -290,7 +331,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) end function func_runtime_generic_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width) body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) quote @@ -302,7 +343,7 @@ end @generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) end @@ -323,69 +364,127 @@ end end end +@inline function iterate_unwrap_augfwd_act(args...) 
+ ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + if guaranteed_const(Core.Typeof(arg)) + Const(arg) + else + Active(arg) + end + end +end + +@inline function iterate_unwrap_augfwd_dup(::Val{forwardMode}, args, dargs) where forwardMode + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + if guaranteed_const(ty) + Const(arg) + elseif !forwardMode && active_reg(ty) + Active(arg) + else + Duplicated(arg, dargs[i]) + end + end +end + +@inline function iterate_unwrap_augfwd_batchdup(::Val{forwardMode}, ::Val{Width}, args, dargs) where {forwardMode, Width} + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + if guaranteed_const(ty) + Const(arg) + elseif !forwardMode && active_reg(ty) + Active(arg) + else + BatchDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + dargs[j][i] + end) + end + end +end + +@inline function allFirst(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + res[1] + end +end + +@inline function allZero(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + Ref(make_zero(res)) + end +end + # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function fwddiff_with_return(::Val{width}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {width, Nargs} - tt′ = Enzyme.vaTypeof(args...) 
+function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) - RT = A ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) - - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) - thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(f, args...) -end + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? Duplicated{FT} : Const{FT} -function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = ntuple(i->:(res[1]), Val(Width+1)) - ModifiedBetween = ntuple(i->false, Val(N+1)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) - return quote - args0 = ($(wrapped...),) - args = concat(iterate_unwrap_fwd(args0...)...) 
- - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end + tt = Enzyme.vaEltypes(tt′) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ForwardMode) + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ForwardMode) - annotation = @static if $Width != 1 - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, $Width} - else - Const{rt} - end + annotation = if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} else - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - Duplicated{rt} - else - Const{rt} - end + Const{rt} end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + else + Const{rt} + end + end - res = fwddiff_with_return(Val($Width), dupClosure ? Duplicated(f, df) : Const(f), annotation, args...) - return if annotation <: Const - ReturnType(($(nres...),)) + world = codegen_world_age(FT, tt) + fa = if dupClosure + if width == 1 + Duplicated(f, df) else - if $Width == 1 - ReturnType((res[1], res[2])) - else - ReturnType((res[1], res[2]...)) - end + BatchDuplicated(f, df) + end + else + Const(f) + end + res = thunk(Val(world), FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(fa, args...) + return if annotation <: Const + ReturnType(allFirst(Val(width+1), res)) + else + if width == 1 + ReturnType((res[1], res[2])) + else + ReturnType((res[1], res[2]...)) end end end +function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + return quote + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) 
+ FT = Core.Typeof(f) + fwddiff_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + end +end + function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) quote @@ -397,75 +496,135 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end -function body_runtime_iterate_augfwd(N, Width, wrapped, primttypes) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) - nres3 = ntuple(i->:(res[3]), Val(Width)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) +function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs + ntuple(Val(Nargs)) do i + Base.@_inline_meta + args[i].val + end +end - return quote - args = ($(wrapped...),) - throw(AssertionError("Runtime iterate augmented forward pass unhandled, f=$f df=$df args=$args")) - - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) +function shadow_tuple(::Val{1}, args::Vararg{Annotation, Nargs}) where Nargs + ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if 
args[i] isa Const + args[i].val + else + args[i].dval + end + end +end - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false +function shadow_tuple(::Val{width}, args::Vararg{Annotation, Nargs}) where {width, Nargs} + ntuple(Val(width)) do w + ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if args[i] isa Const + args[i].val + else + args[i].dval[w] end + end + end +end - world = codegen_world_age(FT, Tuple{$(ElTypes...)}) +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} + ReturnPrimal = Val(true) + ModifiedBetween = Val(ModifiedBetween0) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, - annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + tt = Enzyme.vaEltypes(tt′) + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) 
- resT = typeof(origRet) - if annotation <: Const - shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType(($(nres...), tape)) - elseif annotation <: Active - if $Width == 1 - shadow_return = Ref(make_zero(origRet)) - else - shadow_return = ($(nzeros...),) - end - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, shadow_return, tape)) + annotation = if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + end + + internal_tape, origRet, initShadow = if f != Base.tuple + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? Duplicated{FT} : Const{FT} + + fa = if dupClosure + if width == 1 + Duplicated(f, df) else - return ReturnType((origRet, shadow_return..., tape)) + BatchDuplicated(f, df) end + else + Const(f) end + world = codegen_world_age(FT, tt) + forward, adjoint = thunk(Val(world), FA, + annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward(fa, args...) + else + nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(Val(width), args...) 
+ end - @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - + resT = typeof(origRet) + if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, initShadow, tape)) + return ReturnType((ntuple(_ -> origRet, Val(width+1))..., tape)) # replicate the whole primal, not origRet[1] + elseif annotation <: Active + if width == 1 + shadow_return = Ref(make_zero(origRet)) else - return ReturnType((origRet, initShadow..., tape)) + shadow_return = allZero(Val(width), origRet) end + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if width == 1 + return ReturnType((origRet, shadow_return, tape)) + else + return ReturnType((origRet, shadow_return..., tape)) + end + end + + @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed + + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if width == 1 + return ReturnType((origRet, initShadow, tape)) + else + return ReturnType((origRet, initShadow..., tape)) + end +end + +function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + return quote + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) 
+ FT = Core.Typeof(f) + augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType end end function func_runtime_iterate_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) - body = body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, allargs, typeargs, wrapped, _, modbetween = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) + body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) quote function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} @@ -476,11 +635,139 @@ end @generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, _, _, wrapped, _ , modbetween, = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) +end + + + +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, tt′, DF, Nargs} + ReturnPrimal = Val(true) + ModifiedBetween = Val(ModifiedBetween0) + + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? 
Duplicated{FT} : Const{FT} + + tt = Enzyme.vaEltypes(tt′) + + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + end + + tup = if f != Base.tuple + world = codegen_world_age(FT, tt) + + fa = if dupClosure + if width == 1 + Duplicated(f, df) + else + BatchDuplicated(f, df) + end + else + Const(f) + end + forward, adjoint = thunk(Val(world), FA, + annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + args2 = if tape.shadow_return !== nothing + if width == 1 + (args..., tape.shadow_return[]) + else + (args..., ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][] + end) + end + else + args + end + + adjoint(fa, args2..., tape.internal_tape)[1] + else + ntuple(Val(Nargs)) do i + Base.@_inline_meta + if args[i] isa Active + if width == 1 + tape.shadow_return[][i] + else + ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][][i] + end + end + else + nothing + end + end + end + + ntuple(Val(Nargs)) do i + Base.@_inline_meta + + ntuple(Val(width)) do w + Base.@_inline_meta + + if tup[i] == nothing + else + expr = if width == 1 + tup[i] + else + tup[i][w] + end + idx_of_vec, idx_in_vec = lengths[i] + vec = @inbounds shadowargs[idx_of_vec][w] + if vec isa Base.RefValue + vecld = vec[] + T = Core.Typeof(vecld) + vec[] = splatnew(T, ntuple(Val(fieldcount(T))) do i + Base.@_inline_meta + prev = getfield(vecld, i) + if i == idx_in_vec + recursive_add(prev, expr) + else + prev + end + end) + else + val = @inbounds vec[idx_in_vec] + if val isa Base.RefValue + 
val[] = recursive_add(val[], expr) + elseif ismutable(vec) + @inbounds vec[idx_in_vec] = recursive_add(val, expr) + else + error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") + end + end + end + + nothing + end + + nothing + end + nothing end -function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) +function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs) outs = [] for i in 1:N for w in 1:Width @@ -494,7 +781,7 @@ function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) elseif $shad isa Base.RefValue $shad[] = recursive_add($shad[], $expr) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + error("Enzyme Mutability Error: Cannot add in place to immutable value "*string($shad)) end ) push!(outs, out) @@ -514,40 +801,30 @@ function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + lengths = ntuple(i->quote + (ntuple(Val(length($(primargs[i])))) do j + Base.@_inline_meta + ($i, j) + end) + end, Val(N)) + + shadowsplat = Expr[] + for s in shadowargs + push!(shadowsplat, :(($(s...),))) + end quote - args = ($(wrapped...),) - throw(AssertionError("Runtime iterate reverse pass unhandled, f=$f df=$df args=$args")) - - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt = Tuple{$(ElTypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, tt) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) - - dupClosure = ActivityTup[1] + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - world = codegen_world_age(FT, tt) - - forward, adjoint = thunk(Val(world), (dupClosure ? 
Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - if tape.shadow_return !== nothing - args = (args..., $shadowret) - end - - tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] - - $(outs...) + rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowsplat...),), args...) return nothing end end function func_runtime_iterate_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) - body = body_runtime_iterate_rev(N, Width, wrapped, primtypes, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true) + body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) quote function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} @@ -558,8 +835,8 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) end # Create specializations @@ -697,7 +974,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end debug_from_orig!(gutils, cal, orig) - + if tape === nothing llty = convert(LLVMType, ReturnType) cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) @@ -778,7 +1055,7 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) width = get_width(gutils) sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) - + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) @@ -1074,18 +1351,6 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function error_if_active_iter(arg) - # check if it could contain an active - for v in arg - seen = () - T = Core.Typeof(v) - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) - if areg == ActiveState - throw(AssertionError("Found unhandled active variable in tuple splat, jl_apply_iterate $T")) - end - end -end - function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) @@ -1100,51 +1365,41 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, width = get_width(gutils) - if v && v2 && isiter == Base.iterate && istup == Base.tuple && 
length(operands(orig)) >= offset+4 - origops = collect(operands(orig)[1:end-1]) - shadowins = [ invert_pointer(gutils, origops[i], B) for i in (offset+3):length(origops) ] - shadowres = if width == 1 - newops = LLVM.Value[] - newvals = API.CValueType[] - for (i, v) in enumerate(origops) - if i >= offset + 3 - shadowin2 = shadowins[i-offset-3+1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) - push!(newops, shadowin2) - push!(newvals, API.VT_Shadow) - else - push!(newops, new_from_original(gutils, origops[i])) - push!(newvals, API.VT_Primal) - end - end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) - callconv!(cal, callconv(orig)) - cal - else - ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[] - newvals = API.CValueType[] - for (i, v) in enumerate(origops) - if i >= offset + 3 - shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) - push!(newops, shadowin2) - push!(newvals, API.VT_Shadow) - else - push!(newops, new_from_original(gutils, origops[i])) - push!(newvals, API.VT_Primal) - end + if v && isiter == Base.iterate + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + sret = generic_setup(orig, runtime_iterate_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+2, B, false) + AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + + if unsafe_load(shadowR) != C_NULL + if width == 1 + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + shadow = LLVM.load!(B, T_prjlvalue, gep) + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + ld = LLVM.load!(B, T_prjlvalue, 
gep) + shadow = insert_value!(B, shadow, ld, i-1) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) - callconv!(cal, callconv(orig)) - shadow = insert_value!(B, shadow, cal, j-1) end - shadow + unsafe_store!(shadowR, shadow.ref) end - unsafe_store!(shadowR, shadowres.ref) + tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + unsafe_store!(tapeR, tape.ref) + + if normalR != C_NULL + normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) + end + return false return false end @@ -1155,6 +1410,17 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, end function common_apply_iterate_rev(offset, B, orig, gutils, tape) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing + end + + @assert tape !== C_NULL + width = get_width(gutils) + generic_setup(orig, runtime_iterate_rev, Nothing, gutils, #=start=#offset+2, B, true; tape) return nothing end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index c6639eb8aa..1ee4f0d961 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -249,7 +249,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) 
where {T, symname, isconst} +function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname) else @@ -260,40 +260,57 @@ function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else - return ( (Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) + return NT(ntuple(Val(1+length(dptrs))) do i + Base.@_inline_meta + Ref{RT}(make_zero(res)) + end) end else if length(dptrs) == 0 return res else - return (res, (getfield(dv, symname) for dv in dptrs)...) + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + getfield(dv isa Base.RefValue ? dv[] : dv, symname) + end)...)) + return fval end end end -function idx_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} +function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) else Base.getfield(dptr, symname+1) end RT = Core.Typeof(res) - if active_reg(RT) + actreg = active_reg(RT) + if actreg if length(dptrs) == 0 - return Ref{RT}(make_zero(res)) + return Ref{RT}(make_zero(res))::Any else - return ( (Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) + return NT(ntuple(Val(1+length(dptrs))) do i + Base.@_inline_meta + Ref{RT}(make_zero(res)) + end) end else if length(dptrs) == 0 - return res + return res::Any else - return (res, (getfield(dv, symname) for dv in dptrs)...) + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + getfield(dv isa Base.RefValue ? dv[] : dv, symname+1) + end)...)) + return fval end end end -function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) 
where {T, symname, isconst} +function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue getfield(dptr[], symname) else @@ -303,17 +320,65 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield!(dptr, symname, recursive_add(cur, dret[])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i + Base.@_inline_meta + prev = getfield(vload, i) + if fieldname(dRT, i) == symname + recursive_add(prev, dret[]) + else + prev + end + end) + else + setfield!(dptr, symname, recursive_add(cur, dret[])) + end else - setfield!(dptr, symname, recursive_add(cur, dret[1][])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if fieldname(dRT, j) == symname + recursive_add(prev, dret[1][]) + else + prev + end + end) + else + setfield!(dptr, symname, recursive_add(cur, dret[1][])) + end for i in 1:length(dptrs) - setfield!(dptrs[i], symname, recursive_add(cur, dret[1+i][])) + if dptrs[i] isa Base.RefValue + vload = dptrs[i][] + dRT = Core.Typeof(vload) + dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if fieldname(dRT, j) == symname + recursive_add(prev, dret[1+i][]) + else + prev + end + end) + else + curi = if dptr isa Base.RefValue + Base.getfield(dptrs[i][], symname) + else + Base.getfield(dptrs[i], symname) + end + setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][])) + end end end end return nothing end -function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) 
where {T, symname, isconst} + +function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) else @@ -323,11 +388,58 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield!(dptr, symname+1, recursive_add(cur, dret[])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i + Base.@_inline_meta + prev = getfield(vload, i) + if i == symname+1 + recursive_add(prev, dret[]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[])) + end else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if j == symname+1 + recursive_add(prev, dret[1][]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + end for i in 1:length(dptrs) - setfield!(dptrs[i], symname+1, recursive_add(cur, dret[1+i][])) + if dptrs[i] isa Base.RefValue + vload = dptrs[i][] + dRT = Core.Typeof(vload) + dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if j == symname+1 + recursive_add(prev, dret[1+i][]) + else + prev + end + end) + else + curi = if dptr isa Base.RefValue + Base.getfield(dptrs[i][], symname+1) + else + Base.getfield(dptrs[i], symname+1) + end + setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][])) + end end end end @@ -362,7 +474,8 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta inps = [new_from_original(gutils, ops[2])] end - vals = LLVM.Value[] + AA = Val(AnyArray(Int(width))) + vals = 
LLVM.Value[unsafe_to_llvm(AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[3]) @@ -539,7 +652,8 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) inps = [new_from_original(gutils, ops[1])] end - vals = LLVM.Value[] + AA = Val(AnyArray(Int(width))) + vals = LLVM.Value[unsafe_to_llvm(AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[2]) diff --git a/src/utils.jl b/src/utils.jl index a3268c6c94..916818181e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,7 +8,7 @@ @inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) export unsafe_to_pointer -@inline is_concrete_tuple(x::T2) where T2 = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) +@inline is_concrete_tuple(x::Type{T2}) where T2 = (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) export is_concrete_tuple const Tracked = 10 diff --git a/test/applyiter.jl b/test/applyiter.jl new file mode 100644 index 0000000000..2518e2d829 --- /dev/null +++ b/test/applyiter.jl @@ -0,0 +1,491 @@ +using Enzyme, Test + +concat() = () +concat(a) = a +concat(a, b) = (a..., b...) +concat(a, b, c...) = concat(concat(a, b), c...) + +metaconcat(x) = concat(x...) + +metaconcat2(x, y) = concat(x..., y...) + +midconcat(x, y) = (x, concat(y...)...) + +metaconcat3(x, y, z) = concat(x..., y..., z...) + +function metasumsq(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v::Float64 + res += v*v + end + return res +end + +function metasumsq2(f, args...) + res = 0.0 + x = f(args...) + for v in x + for v2 in v + v2 = v2::Float64 + res += v*v + end + end + return res +end + + +function metasumsq3(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v + res += v*v + end + return res +end + +function metasumsq4(f, args...) + res = 0.0 + x = f(args...) + for v in x + for v2 in v + v2 = v2 + res += v*v + end + end + return res +end + +function make_byref(out, fn, args...) + out[] = fn(args...) 
+ nothing +end + +function tupapprox(a, b) + if a isa Tuple && b isa Tuple + if length(a) != length(b) + return false + end + for (aa, bb) in zip(a, b) + if !tupapprox(aa, bb) + return false + end + end + return true + end + if a isa Array && b isa Array + if size(a) != size(b) + return false + end + for i in length(a) + if !tupapprox(a[i], b[i]) + return false + end + end + return true + end + return a ≈ b +end + +@testset "Reverse Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + res = Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + dx = [(0.0, 0.0), (0.0, 0.0)] + res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + + res = Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + + dx = [[0.0, 0.0], [0.0, 0.0]] + + res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + + y = [(13, 17), (25, 31)] + res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + y = [(13, 17), (25, 31)] + dy = [(0, 0), (0, 0)] + res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + res = Enzyme.autodiff(Reverse, metasumsq4, Active, 
Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + dy = [[0, 0], [0, 0]] + res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) +end + +@testset "BatchReverse Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ 200.84999999999997 + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + @test dx2 ≈ [[3*4.0, 3*6.0], [3*15.8, 3*22.4]] + + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + + @test out[] ≈ 200.84999999999997 + @test tupapprox(dx, 
[[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + + y = [(13, 17), (25, 31)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + y = [(13, 17), (25, 31)] + dy = [(0, 0), (0, 0)] + dy2 = [(0, 0), (0, 0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3),Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + dy = [[0, 0], [0, 0]] + dy2 = [[0, 0], [0, 0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + 
@test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) +end + +@testset "Forward Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(13.7, 15.2), (100.02, 304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + + a = [("a", "b"), ("c", "d")] + da = [("e", "f"), ("g", "h")] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) + @test length(res) == 4 + @test res[1] == "a" + @test res[2] == "b" + @test res[3] == "c" + @test res[4] == "d" + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + + Enzyme.autodiff(Forward, metaconcat, Const(a)) + +@static if VERSION ≥ v"1.7-" + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + + dres, = Enzyme.autodiff(Forward, midconcat, 
Duplicated(1.0, 7.0), Const(a)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" +end + + y = [(-92.0, -93.0), (-97.9, -911.2)] + dy = [(-913.7, -915.2), (-9100.02, -9304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) + @test length(res) == 8 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test res[5] ≈ -92.0 + @test res[6] ≈ -93.0 + @test res[7] ≈ -97.9 + @test res[8] ≈ -911.2 + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), 
Duplicated(y, dy)) + @test length(res) == 12 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + + @test res[5] == "a" + @test res[6] == "b" + @test res[7] == "c" + @test res[8] == "d" + + @test res[9] ≈ -92.0 + @test res[10] ≈ -93.0 + @test res[11] ≈ -97.9 + @test res[12] ≈ -911.2 + + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 +end + +@testset "legacy reverse apply iterate" begin + function mktup(v) + tup = tuple(v...) + return tup[1][1] * tup[3][1] + end + + data = [[3.0], nothing, [2.0]] + ddata = [[0.0], nothing, [0.0]] + + Enzyme.autodiff(Reverse, mktup, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 2.0 + @test ddata[3][1] ≈ 3.0 + + function mktup2(v) + tup = tuple(v...) 
+ return (tup[1][1] * tup[3])::Float64 + end + + data = [[3.0], nothing, 2.0] + ddata = [[0.0], nothing, 0.0] + + @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) + + function mktup3(v) + tup = tuple(v..., v...) + return tup[1][1] * tup[1][1] + end + + data = [[3.0]] + ddata = [[0.0]] + + Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 6.0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 0212ec0d83..e931666f90 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -137,6 +137,7 @@ end @assert Enzyme.Compiler.active_reg_inner(Symbol, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(String, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Tuple{Any,Int64}, (), nothing) == Enzyme.Compiler.DupState + @assert Enzyme.Compiler.active_reg_inner(Tuple{S,Int64} where S, (), Base.get_world_counter()) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing, #=justActive=#Val(false), #=unionSret=#Val(true)) == Enzyme.Compiler.ActiveState world = codegen_world_age(typeof(f0), Tuple{Float64}) @@ -1670,232 +1671,38 @@ end end -concat() = () -concat(a) = a -concat(a, b) = (a..., b...) -concat(a, b, c...) = concat(concat(a, b), c...) - -metaconcat(x) = concat(x...) - -metaconcat2(x, y) = concat(x..., y...) - -midconcat(x, y) = (x, concat(y...)...) - -metaconcat3(x, y, z) = concat(x..., y..., z...) 
- -@testset "Forward Apply iterate" begin - x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(13.7, 15.2), (100.02, 304.1)] - - dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) - @test length(dres) == 4 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) - @test length(res) == 4 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test length(dres) == 4 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - - a = [("a", "b"), ("c", "d")] - da = [("e", "f"), ("g", "h")] - - dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) - @test length(dres) == 4 - @test dres[1] == "a" - @test dres[2] == "b" - @test dres[3] == "c" - @test dres[4] == "d" - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) - @test length(res) == 4 - @test res[1] == "a" - @test res[2] == "b" - @test res[3] == "c" - @test res[4] == "d" - @test length(dres) == 4 - @test dres[1] == "a" - @test dres[2] == "b" - @test dres[3] == "c" - @test dres[4] == "d" - - - Enzyme.autodiff(Forward, metaconcat, Const(a)) +function batchgf(out, args) + res = 0.0 + x = Base.inferencebarrier((args[1][1],)) + for v in x + v = v::Float64 + res += v + break + end + out[] = res + nothing +end -@static if VERSION ≥ v"1.7-" - dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) - @test length(res) == 5 - @test res[1] ≈ 1.0 - @test res[2] == "a" - @test res[3] == "b" - @test res[4] == "c" - @test res[5] == "d" - - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - 
@test dres[4] == "c" - @test dres[5] == "d" - - - dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Const(a)) - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) - @test length(res) == 5 - @test res[1] ≈ 1.0 - @test res[2] == "a" - @test res[3] == "b" - @test res[4] == "c" - @test res[5] == "d" - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" -end - - y = [(-92.0, -93.0), (-97.9, -911.2)] - dy = [(-913.7, -915.2), (-9100.02, -9304.1)] - - dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) - @test length(dres) == 8 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - @test dres[5] ≈ -913.7 - @test dres[6] ≈ -915.2 - @test dres[7] ≈ -9100.02 - @test dres[8] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) - @test length(res) == 8 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test res[5] ≈ -92.0 - @test res[6] ≈ -93.0 - @test res[7] ≈ -97.9 - @test res[8] ≈ -911.2 - @test length(dres) == 8 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - @test dres[5] ≈ -913.7 - @test dres[6] ≈ -915.2 - @test dres[7] ≈ -9100.02 - @test dres[8] ≈ -9304.1 - - - dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) - @test length(dres) == 12 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - @test dres[5] == "a" - @test dres[6] == "b" - @test dres[7] == "c" - @test dres[8] == "d" - - @test dres[9] ≈ -913.7 - @test dres[10] ≈ -915.2 - @test dres[11] ≈ -9100.02 - @test dres[12] ≈ -9304.1 - 
- res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) - @test length(res) == 12 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - - @test res[5] == "a" - @test res[6] == "b" - @test res[7] == "c" - @test res[8] == "d" - - @test res[9] ≈ -92.0 - @test res[10] ≈ -93.0 - @test res[11] ≈ -97.9 - @test res[12] ≈ -911.2 - - @test length(dres) == 12 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - @test dres[5] == "a" - @test dres[6] == "b" - @test dres[7] == "c" - @test dres[8] == "d" - - @test dres[9] ≈ -913.7 - @test dres[10] ≈ -915.2 - @test dres[11] ≈ -9100.02 - @test dres[12] ≈ -9304.1 - - - dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) - @test length(dres[1]) == 4 - @test dres[1][1] ≈ 13.7 - @test dres[1][2] ≈ 15.2 - @test dres[1][3] ≈ 100.02 - @test dres[1][4] ≈ 304.1 - @test length(dres[2]) == 4 - @test dres[2][1] ≈ -913.7 - @test dres[2][2] ≈ -915.2 - @test dres[2][3] ≈ -9100.02 - @test dres[2][4] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) - @test length(res) == 4 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test length(dres[1]) == 4 - @test dres[1][1] ≈ 13.7 - @test dres[1][2] ≈ 15.2 - @test dres[1][3] ≈ 100.02 - @test dres[1][4] ≈ 304.1 - @test length(dres[2]) == 4 - @test dres[2][1] ≈ -913.7 - @test dres[2][2] ≈ -915.2 - @test dres[2][3] ≈ -9100.02 - @test dres[2][4] ≈ -9304.1 +@testset "Batch Getfield" begin + x = [(2.0, 3.0)] + dx = [(0.0, 0.0)] + dx2 = [(0.0, 0.0)] + dx3 = [(0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + dout3 = Ref(5.0) + Enzyme.autodiff(Reverse, batchgf, Const, BatchDuplicatedNoNeed(out, (dout, dout2, dout3)), BatchDuplicated(x, (dx, dx2, dx3))) + @test dx[1][1] ≈ 1.0 + @test dx[1][2] ≈ 0.0 + @test dx2[1][1] ≈ 3.0 + @test dx2[1][2] ≈ 0.0 + 
@test dx3[1][1] ≈ 5.0 + @test dx3[1][2] ≈ 0.0 end +include("applyiter.jl") + @testset "Dynamic Val Construction" begin dyn_f(::Val{D}) where D = prod(D) @@ -2566,41 +2373,6 @@ end Enzyme.API.runtimeActivity!(false) end -@testset "apply iterate" begin - function mktup(v) - tup = tuple(v...) - return tup[1][1] * tup[3][1] - end - - data = [[3.0], nothing, [2.0]] - ddata = [[0.0], nothing, [0.0]] - - Enzyme.autodiff(Reverse, mktup, Duplicated(data, ddata)) - @test ddata[1][1] ≈ 2.0 - @test ddata[3][1] ≈ 3.0 - - function mktup2(v) - tup = tuple(v...) - return (tup[1][1] * tup[3])::Float64 - end - - data = [[3.0], nothing, 2.0] - ddata = [[0.0], nothing, 0.0] - - @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) - - function mktup3(v) - tup = tuple(v..., v...) - return tup[1][1] * tup[1][1] - end - - data = [[3.0]] - ddata = [[0.0]] - - Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) - @test ddata[1][1] ≈ 6.0 -end - @testset "BLAS" begin x = [2.0, 3.0] dx = [0.2,0.3] From ffcc20c81b4fb1a1f1785f75f29c61df63f9f677 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 10 Jun 2024 11:02:26 -0500 Subject: [PATCH 110/495] Fix const-only apply iterate (#1526) * Fix const-only apply iterate * fix ct * Fix mixed activity for type unstable * Update jitrules.jl * Update jitrules.jl * wip tuple * fix batch tuple generation * Ensure runtime store error * fix * cleanup * ignore 1.8 * newstructv * ignore test --- src/compiler.jl | 43 +++- src/rules/jitrules.jl | 390 +++++++++++++++++++++-------- src/rules/typeunstablerules.jl | 445 +++++++++++++++++++++++++++++---- test/applyiter.jl | 14 ++ test/mixed.jl | 71 ++++++ test/runtests.jl | 1 + test/threads.jl | 3 +- 7 files changed, 808 insertions(+), 159 deletions(-) create mode 100644 test/mixed.jl diff --git a/src/compiler.jl b/src/compiler.jl index cca67bc874..ed44563ec4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -474,7 +474,7 @@ end end @assert !Base.isabstracttype(T) - if 
!(Base.isconcretetype(T) || is_concrete_tuple(T) || T isa UnionAll) + if !(Base.isconcretetype(T) || (T <: Tuple && T != Tuple) || T isa UnionAll) throw(AssertionError("Type $T is not concrete type or concrete tuple")) end @@ -515,7 +515,7 @@ end return active_reg_inner(T, (), world) end -@inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} +Base.@pure @inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} seen = () # check if it could contain an active @@ -3342,6 +3342,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr world = job.world interp = GPUCompiler.get_interpreter(job) rt = job.config.params.rt + @assert eltype(rt) != Union{} + shadow_init = job.config.params.shadowInit ctx = context(mod) dl = string(LLVM.datalayout(mod)) @@ -3546,6 +3548,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, pactualRetType = actualRetType sret_union = is_sret_union(actualRetType) literal_rt = eltype(rettype) + @assert literal_rt != Union{} sret_union_rt = is_sret_union(literal_rt) @assert sret_union == sret_union_rt if sret_union @@ -3684,9 +3687,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - combinedReturn = Tuple{sret_types...} - if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) - combinedReturn = AnonymousStruct(combinedReturn) + combinedReturn = if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) + AnonymousStruct(Tuple{sret_types...}) + else + Tuple{sret_types...} end uses_sret = is_sret(combinedReturn) @@ -4794,6 +4798,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) 
params = job.config.params + if params.run_enzyme + @assert eltype(params.rt) != Union{} + end expectedTapeType = params.expectedTapeType mode = params.mode TT = params.TT @@ -4801,7 +4808,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; abiwrap = params.abiwrap primal = job.source modifiedBetween = params.modifiedBetween - @assert length(modifiedBetween) == length(TT.parameters) + if length(modifiedBetween) != length(TT.parameters) + throw(AssertionError("length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT")) + end returnPrimal = params.returnPrimal if !(params.rt <: Const) @@ -5297,6 +5306,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end @assert actualRetType !== nothing + if params.run_enzyme + @assert actualRetType != Union{} + end if must_wrap llvmfn = primalf @@ -5838,7 +5850,11 @@ end end push!(ccexprs, argexpr) - if !(FA <: Const) + if (FA <: Active) + return quote + error("Cannot have function with Active annotation, $FA") + end + elseif !(FA <: Const) argexpr = :(fn.dval) if isboxed push!(types, Any) @@ -6274,9 +6290,16 @@ end compile_result = cached_compilation(job) if !run_enzyme ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World} - return quote - Base.@_inline_meta - $ErrT($(compile_result.adjoint)) + if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient + return quote + Base.@_inline_meta + ($ErrT($(compile_result.adjoint)), $ErrT($(compile_result.adjoint))) + end + else + return quote + Base.@_inline_meta + $ErrT($(compile_result.adjoint)) + end end elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index af12d2bfbc..e5ce78aa02 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,5 +1,87 @@ +function 
func_mixed_call(N) + allargs = Expr[] + typeargs = Union{Symbol,Expr}[] + exprs2 = Union{Symbol,Expr}[] + for i in 1:N + arg = Symbol("arg_$i") + targ = Symbol("T$i") + e = :($arg::$targ) + push!(allargs, e) + push!(typeargs, targ) + + inarg = quote + if RefTypes[1+$i] + $arg[] + else + $arg + end + end + push!(exprs2, inarg) + end + + quote + @generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)} + fexpr = :f + if RefTypes[1] + fexpr = :(($fexpr)[]) + end + exprs2 = Union{Symbol,Expr}[] + for i in 1:$N + arg = Symbol("arg_$i") + inarg = if RefTypes[1+i] + :($arg[]) + else + :($arg) + end + push!(exprs2, inarg) + end + @static if VERSION ≥ v"1.8-" + return quote + Base.@_inline_meta + @inline $fexpr($(exprs2...)) + end + else + return quote + Base.@_inline_meta + $fexpr($(exprs2...)) + end + end + end + end +end -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false) +@generated function runtime_mixed_call(::Val{RefTypes}, f::F, allargs::Vararg{Any, N}) where {RefTypes, F, N} + fexpr = :f + if RefTypes[1] + fexpr = :(($fexpr)[]) + end + exprs2 = Union{Symbol,Expr}[] + for i in 1:N + inarg = if RefTypes[1+i] + :(allargs[$i][]) + else + :(allargs[$i]) + end + push!(exprs2, inarg) + end + @static if VERSION ≥ v"1.8-" + return quote + Base.@_inline_meta + @inline $fexpr($(exprs2...)) + end + else + return quote + Base.@_inline_meta + $fexpr($(exprs2...)) + end + end +end + +for N in 0:10 + eval(func_mixed_call(N)) +end + +function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -8,18 +90,20 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, typeargs = Symbol[] dfns = Union{Symbol,Expr}[:df] base_idx = 1 - for w in 2:Width - if base === nothing - shad = 
Symbol("df_$w") - t = Symbol("DF__$w*") - e = :($shad::$t) - push!(allargs, e) - push!(typeargs, t) - else - shad = :($base[$base_idx]) - base_idx += 1 + if func + for w in 2:Width + if base === nothing + shad = Symbol("df_$w") + t = Symbol("DF__$w*") + e = :($shad::$t) + push!(allargs, e) + push!(typeargs, t) + else + shad = :($base[$base_idx]) + base_idx += 1 + end + push!(dfns, shad) end - push!(dfns, shad) end for i in 1:N if base === nothing @@ -60,6 +144,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, @assert length(primtypes) == N wrapped = Expr[] modbetween = Expr[:(MB[1])] + active_refs = Expr[] for i in 1:N if iterate push!(modbetween, quote @@ -69,6 +154,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end end) end + aref = Symbol("active_ref_$i") + push!(active_refs, quote + $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)); + end) expr = if iterate :( if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) @@ -88,23 +177,57 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end ) else - :( - if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - @assert $(primtypes[i]) !== DataType - if !$forwardMode && active_reg($(primtypes[i])) - Active($(primargs[i])) - else - $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) - end - else - Const($(primargs[i])) - end - - ) + if forwardMode + quote + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + else + Const($(primargs[i])) + end + end + else + quote + if ActivityTup[$i+1] && $aref != AnyState + @assert $(primtypes[i]) !== DataType + if $aref == ActiveState + Active($(primargs[i])) + elseif $aref == MixedState + $((Width == 1) ? :Duplicated : :BatchDuplicated)(Ref($(primargs[i])), $(shadowargs[i])) + else + $((Width == 1) ? 
:Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + end + else + Const($(primargs[i])) + end + end + end end push!(wrapped, expr) end - return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween + + any_mixed = quote false end + for i in 1:N + aref = Symbol("active_ref_$i") + if mixed_or_active + any_mixed = :($any_mixed || $aref == MixedState || $aref == ActiveState) + else + any_mixed = :($any_mixed || $aref == MixedState) + end + end + + if mixed_or_active + push!(active_refs, quote + active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState) for i in 1:N)...)) + end) + else + push!(active_refs, quote + active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState) for i in 1:N)...)) + end) + end + push!(active_refs, quote + any_mixed = $any_mixed + end) + return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -159,7 +282,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function func_runtime_generic_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _, _ = setup_macro_wraps(true, N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote @@ -171,46 +294,75 @@ end @generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end -function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) +function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) nnothing = ntuple(i->nothing, Val(Width+1)) nres = ntuple(i->:(origRet), Val(Width+1)) nzeros = ntuple(i->:(Ref(make_zero(origRet))), Val(Width)) nres3 = ntuple(i->:(res[3]), Val(Width)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + ElTypes = ntuple(i->:(eltype($(Symbol("type_$i")))), Val(N)) + + MakeTypes = ntuple(i->:($(Symbol("type_$i")) = Core.Typeof(args[$i])), Val(N)) + + Types = ntuple(i->Symbol("type_$i"), Val(N)) + + MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) return quote + $(active_refs...) args = ($(wrapped...),) + $(MakeTypes...) 
- # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - - annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + FT = Core.Typeof(f) + dupClosure0 = if ActivityTup[1] + !guaranteed_const(FT) else - annotation0 + false end - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - world = codegen_world_age(FT, Tuple{$(ElTypes...)}) + internal_tape, origRet, initShadow, annotation = if any_mixed + ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} + rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) + annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, - annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + annotationM = if $Width != 1 && annotation0M <: Duplicated + BatchDuplicated{rtM, $Width} + else + annotation0M + end + worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) + ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) + + forward, adjoint = thunk(Val(worldM), + Const{typeof(runtime_mixed_call)}, + annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + forward(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? 
Duplicated(f, df) : Const(f), args...)..., annotationM + + else + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotationA = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end + world = codegen_world_age(FT, tt) + + forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationA + end - internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) resT = typeof(origRet) if annotation <: Const shadow_return = nothing @@ -243,8 +395,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) end function func_runtime_generic_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(false, N, Width) - body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width) + body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) quote function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} @@ -255,11 +407,11 @@ end @generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _= setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_augfwd(N, Width, wrapped, 
primtypes) + _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) end -function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) +function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs) outs = [] for i in 1:N for w in 1:Width @@ -273,7 +425,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) elseif $shad isa Base.RefValue $shad[] = recursive_add($shad[], $expr) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)*" tup[i]="*string(tup[$i])*" i="*string($i)*" w="*string($w)*" tup="*string(tup)) end ) push!(outs, out) @@ -290,49 +442,81 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) shadowret = :(($(shadowret...),)) end - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + ElTypes = ntuple(i->:(eltype($(Symbol("type_$i")))), Val(N)) + + MakeTypes = ntuple(i->:($(Symbol("type_$i")) = Core.Typeof(args[$i])), Val(N)) + + Types = ntuple(i->Symbol("type_$i"), Val(N)) + + MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) quote + $(active_refs...) args = ($(wrapped...),) + $(MakeTypes...) 
+ + FT = Core.Typeof(f) + dupClosure0 = if ActivityTup[1] + !guaranteed_const(FT) + else + false + end - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt = Tuple{$(ElTypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + if any_mixed + ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} + rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) + annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + annotationM = if $Width != 1 && annotation0M <: Duplicated + BatchDuplicated{rtM, $Width} + else + annotation0M + end + worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) + ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) + + _, adjoint = thunk(Val(worldM), + Const{typeof(runtime_mixed_call)}, + annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing + adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape) + else + adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape) + end + nothing else - annotation0 - end + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - world = codegen_world_age(FT, tt) + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end - forward, adjoint = thunk(Val(world), (dupClosure ? 
Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + world = codegen_world_age(FT, tt) - if tape.shadow_return !== nothing - args = (args..., $shadowret) - end + _, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + tup = if tape.shadow_return !== nothing + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + else + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + end - $(outs...) + $(outs...) + end return nothing end end function func_runtime_generic_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width) - body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) + body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) quote function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} @@ -343,8 +527,8 @@ end @generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) + _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) end @inline concat() = () @@ -416,6 +600,13 @@ end end end +@inline function allSame(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + res + end +end + @inline function allZero(::Val{Width}, res) where Width ntuple(Val(Width)) do i Base.@_inline_meta @@ -484,7 +675,7 @@ function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) quote @@ -496,7 +687,7 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) + _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end @@ -586,7 +777,7 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType((allFirst(Val(width+1), origRet)..., tape)) + return ReturnType((allSame(Val(width+1), origRet)..., tape)) elseif annotation <: Active if width == 1 shadow_return = Ref(make_zero(origRet)) @@ -623,7 +814,7 @@ function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) end function func_runtime_iterate_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, modbetween = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) + _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) quote @@ -635,7 +826,7 @@ end @generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ , modbetween, = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) end @@ -835,7 +1026,7 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) end @@ -849,7 +1040,7 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false) +function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -862,8 +1053,6 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, ActivityList = LLVM.Value[] - to_preserve = LLVM.Value[] - @assert length(ops) != 0 fill_val = unsafe_to_llvm(nothing) @@ -918,9 +1107,6 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else ev = extract_value!(B, inverted, w-1) end - if tape !== nothing - push!(to_preserve, ev) - end end push!(vals, ev) @@ 
-929,7 +1115,13 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, @assert length(ActivityList) == length(ops) if tape !== nothing - pushfirst!(vals, tape) + if tape isa Vector + for t in reverse(tape) + pushfirst!(vals, t) + end + else + pushfirst!(vals, tape) + end else pushfirst!(vals, unsafe_to_llvm(Val(ReturnType))) end @@ -975,7 +1167,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, debug_from_orig!(gutils, cal, orig) - if tape === nothing + if tape === nothing && endcast llty = convert(LLVMType, ReturnType) cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 1ee4f0d961..101796401f 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1,8 +1,229 @@ +function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, tuple) + shadow_rets = Vector{Expr}[] + results = quote + $(active_refs...) 
+ end + @assert length(primtypes) == N + @assert length(primargs) == N + @assert length(batchshadowargs) == N + for i in 1:N + @assert length(batchshadowargs[i]) == Width + shadow_rets_i = Expr[] + aref = Symbol("active_ref_$i") + for w in 1:Width + sref = Symbol("shadow_"*string(i)*"_"*string(w)) + push!(shadow_rets_i, quote + $sref = if $aref == AnyState + $(primargs[i]); + else + if !ActivityTup[$i] + if $aref == DupState || $aref == MixedState + prim = $(primargs[i]) + throw("Error cannot store inactive but differentiable variable $prim into active tuple") + end + end + if $aref == DupState + $(batchshadowargs[i][w]) + else + $(batchshadowargs[i][w])[] + end + end + end) + end + push!(shadow_rets, shadow_rets_i) + end -function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL - return true + refs = Expr[] + ref_syms = Symbol[] + res_syms = Symbol[] + for w in 1:Width + sres = Symbol("result_$w") + ref_res = Symbol("ref_result_$w") + combined = Expr[] + for i in 1:N + push!(combined, shadow_rets[i][w]) + end + if tuple + results = quote + $results + $sres = ($(combined...),) + end + else + results = quote + $results + $sres = $(Expr(:new, :NewType, combined...)) + end + end + push!(refs, quote + $ref_res = Ref($sres) + end) + push!(ref_syms, ref_res) + push!(res_syms, sres) end + + if Width == 1 + return quote + $results + if any_mixed + $(refs...) + $(ref_syms[1]) + else + $(res_syms[1]) + end + end + else + return quote + $results + if any_mixed + $(refs...) 
+ ReturnType(($(ref_syms...),)) + else + ReturnType(($(res_syms...),)) + end + end + end +end + + +function body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, tuple) + outs = [] + for i in 1:N + for w in 1:Width + tsym = Symbol("tval_$w") + expr = if tuple + :($tsym[$i]) + else + :(getfield($tsym, $i)) + end + shad = batchshadowargs[i][w] + out = :(if $(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState + if $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr) + else + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + end + end + ) + push!(outs, out) + end + end + + tapes = Expr[:(tval_1 = tape[])] + for w in 2:Width + sym = Symbol("tval_$w") + df = Symbol("df_$w") + push!(tapes, :($sym = $df[])) + end + + quote + $(active_refs...) + + if any_mixed + $(tapes...) + $(outs...) + end + return nothing + end +end + + +function body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) +end + +function body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) +end + + +function body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) +end + +function func_runtime_tuple_augfwd(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; func=false, mixed_or_active=true) + body = body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, $(allargs...))::ReturnType where 
{ActivityTup, MB, ReturnType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType} + N = div(length(allargs), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; func=false, mixed_or_active=true) + return body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + + +function func_runtime_tuple_rev(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) + body = body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, TapeType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, allargs...) 
where {ActivityTup, MB, Width, TapeType} + N = div(length(allargs)-(Width-1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) + return body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + + +function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) +end + +function func_runtime_newstruct_augfwd(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) + body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType} + N = div(length(allargs)+2, Width+1)-1 + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + +function func_runtime_newstruct_rev(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) + body = body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, 
::Type{NewStruct}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, NewStruct, TapeType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewStruct}, tape::TapeType, allargs...) where {ActivityTup, MB, Width, NewStruct, TapeType} + N = div(length(allargs)-(Width-1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) + return body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + +for (N, Width) in Iterators.product(0:30, 1:10) + eval(func_runtime_newstruct_augfwd(N, Width)) + eval(func_runtime_newstruct_rev(N, Width)) + eval(func_runtime_tuple_augfwd(N, Width)) + eval(func_runtime_tuple_rev(N, Width)) +end + + +# returns if legal and completed +function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) @@ -10,34 +231,35 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) @assert is_constant_value(gutils, origops[offset]) icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] - abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + abs = [abs_typeof(v) for v in origops[offset+1:end-1]] - legal = true - for (icv, (found, typ)) in zip(icvs, abs) + @assert length(icvs) == length(abs) + for (icv, (found_partial, typ_partial), (found, typ)) in zip(icvs, abs_partial, abs) + # Constants not handled unless known inactive from type if icv - if found - if guaranteed_const_nongen(typ, world) - continue - end + if !found_partial + return false + end + if !guaranteed_const_nongen(typ_partial, world) + return false + end + end + # if any active [e.g. 
ActiveState / MixedState] data could exist + # err + if !fwd + if !found + return false + end + act = active_reg_inner(typ, (), world) + if act == MixedState || act == ActiveState + return false end - legal = false end end - # if all(icvs) - # shadowres = new_from_original(gutils, orig) - # if width != 1 - # shadowres2 = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(shadowres)))) - # for idx in 1:width - # shadowres2 = insert_value!(B, shadowres2, shadowres, idx-1) - # end - # shadowres = shadowres2 - # end - # unsafe_store!(shadowR, shadowres.ref) - # return false - # end - if !legal - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs)*" "*string([v for v in origops[offset+1:end-1]])) + if !run + return true end shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:end-1] ] @@ -62,19 +284,72 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) end end unsafe_store!(shadowR, shadowres.ref) + return true +end + + +function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return true + end + + if !newstruct_common(#=fwd=#true, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + origops = collect(operands(orig)) + emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs)*" "*string([v for v in origops[offset+1:end-1]])) + end + return false end + function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - common_newstructv_fwd(offset, B, orig, 
gutils, normalR, shadowR) -end + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) -function error_if_active_newstruct(::Type{T}, ::Type{Y}) where {T, Y} - seen = () - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) - if areg == ActiveState - throw(AssertionError("Found unhandled active variable ($T) in reverse mode of jl_newstruct constructor for $Y")) + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return true end - nothing + + if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + + width = get_width(gutils) + + sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? 
Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false) + + if width == 1 + shadow = sret + else + AT = LLVM.ArrayType(T_prjlvalue, Int(width)) + llty = convert(LLVMType, AnyArray(Int(width))) + cal = sret + cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) + cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) + end + end + unsafe_store!(shadowR, shadow.ref) + + unsafe_store!(tapeR, sret.ref) + return false + end + + return false end function common_newstructv_rev(offset, B, orig, gutils, tape) @@ -90,20 +365,11 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) if !needsShadow return end - - origops = collect(operands(orig)) - width = get_width(gutils) - - world = enzyme_extract_world(LLVM.parent(position(B))) - @assert is_constant_value(gutils, origops[offset]) - icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] - abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] - - - ty = lookup_value(gutils, new_from_original(gutils, origops[offset]), B) - for v in origops[offset+1:end-1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, lookup_value(gutils, new_from_original(gutils, v), B)), ty]) + if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + @assert tape !== C_NULL + width = get_width(gutils) + generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape) end return nothing @@ -112,13 +378,94 @@ end function common_f_tuple_fwd(offset, B, orig, gutils, normalR, shadowR) common_newstructv_fwd(offset, B, orig, 
gutils, normalR, shadowR) end + function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - common_f_tuple_fwd(offset, B, orig, gutils, normalR, shadowR) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 + return true + end + + if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + + width = get_width(gutils) + + sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false) + + if width == 1 + shadow = sret + else + AT = LLVM.ArrayType(T_prjlvalue, Int(width)) + llty = convert(LLVMType, AnyArray(Int(width))) + cal = sret + cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) + cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) + end + end + unsafe_store!(shadowR, shadow.ref) + + unsafe_store!(tapeR, sret.ref) + + return false + end end function common_f_tuple_rev(offset, B, orig, gutils, tape) - # This function allocates a new return which returns a pointer, thus this instruction itself cannot transfer - # derivative info, only create a shadow pointer, which is handled by the forward pass. 
+ needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + needsPrimal = needsPrimalP[] != 0 + needsShadow = needsShadowP[] != 0 + + if !needsShadow + return + end + + if is_constant_value(gutils, orig) + return true + end + + if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + @assert tape !== C_NULL + width = get_width(gutils) + tape2 = if width != 1 + res = LLVM.Value[] + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + AT = LLVM.ArrayType(T_prjlvalue, Int(width)) + llty = convert(LLVMType, AnyArray(Int(width))) + cal = tape + cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) + cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + push!(res, ld) + end + res + else + tape + end + generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2) + end return nothing end diff --git a/test/applyiter.jl b/test/applyiter.jl index 2518e2d829..b1a26e5f54 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -89,6 +89,20 @@ function tupapprox(a, b) return a ≈ b end +@testset "Const Apply iterate" begin + function extiter() + vals = Any[3,] + extracted = Tuple(vals) + return extracted + end + + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(extiter)}, Duplicated) + + tape, res, dres = fwd(Const(extiter)) + @test res == (3,) + @test dres == (3,) +end + @testset "Reverse Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] dx = [(0.0, 0.0), (0.0, 0.0)] diff --git a/test/mixed.jl b/test/mixed.jl new file mode 100644 index 
0000000000..dae0623073 --- /dev/null +++ b/test/mixed.jl @@ -0,0 +1,71 @@ +using Enzyme, Test + +@noinline function mixedmul(tup::T) where T + return tup[1] * tup[2][1] +end + +function outmixedmul(x::Float64) + vec = [x] + tup = (x, vec) + Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +function outmixedmul2(res, x::Float64) + vec = [x] + tup = (x, vec) + res[] = Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +@testset "Basic Mixed Activity" begin + @test 6.2 ≈ Enzyme.autodiff(Reverse, outmixedmul, Active, Active(3.1))[1][1] +end + +@testset "Byref Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + @test 6.2 ≈ Enzyme.autodiff(Reverse, outmixedmul2, Const, Duplicated(res, dres), Active(3.1))[1][2] +end + +@static if VERSION >= v"1.8-" +@testset "Batched Byref Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + dres2 = Ref(3.0) + sig = Enzyme.autodiff(Reverse, outmixedmul2, Const, BatchDuplicated(res, (dres, dres2)), Active(3.1)) + @test 6.2 ≈ sig[1][2][1] + @test 3*6.2 ≈ sig[1][2][2] +end +end + +function tupmixedmul(x::Float64) + vec = [x] + tup = (x, Base.inferencebarrier(vec)) + Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +@testset "Tuple Mixed Activity" begin + @test 6.2 ≈ Enzyme.autodiff(Reverse, tupmixedmul, Active, Active(3.1))[1][1] +end + +function outtupmixedmul(res, x::Float64) + vec = [x] + tup = (x, Base.inferencebarrier(vec)) + res[] = Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +@testset "Byref Tuple Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + @test 6.2 ≈ Enzyme.autodiff(Reverse, outtupmixedmul, Const, Duplicated(res, dres), Active(3.1))[1][2] +end + +@static if VERSION >= v"1.8-" +@testset "Batched Byref Tuple Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + dres2 = Ref(3.0) + sig = Enzyme.autodiff(Reverse, outtupmixedmul, Const, BatchDuplicated(res, (dres, dres2)), Active(3.1)) + @test 6.2 ≈ sig[1][2][1] + @test 3*6.2 ≈ sig[1][2][2] +end +end diff --git a/test/runtests.jl 
b/test/runtests.jl index e931666f90..ca05883c13 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1701,6 +1701,7 @@ end @test dx2[1][2] ≈ 0.0 end +include("mixed.jl") include("applyiter.jl") @testset "Dynamic Val Construction" begin diff --git a/test/threads.jl b/test/threads.jl index 5fe80916d3..6899d8d2d6 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -74,7 +74,8 @@ end out = [1.0, 2.0] dout = [1.0, 1.0] @static if VERSION < v"1.8" - @test_throws AssertionError autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) + # GPUCompiler causes a stack overflow due to https://github.com/JuliaGPU/GPUCompiler.jl/issues/587 + # @test_throws AssertionError autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) else res = autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) @test res[1][2] ≈ 2.0 From b8f9bebc921a173ad331415cd56881af692c0afa Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 10 Jun 2024 09:10:09 -0700 Subject: [PATCH 111/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 848c47e7ee..1292fbc5ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.12" +version = "0.12.13" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From df7dd8798dd0bf9e62bdbd692ad07ad61b5ebe50 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 10 Jun 2024 14:51:50 -0400 Subject: [PATCH 112/495] Handle non-zero mixed return (#1529) * Handle non-zero mixed return * improve mixed activity rule errors --- src/rules/customrules.jl | 12 ++++-- src/rules/jitrules.jl | 93 +++++++++++++++++++++++++++++++++++----- 2 files changed, 90 insertions(+), 15 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index e8c573a176..989a733c01 100644 --- a/src/rules/customrules.jl +++ 
b/src/rules/customrules.jl @@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, return args, activity, (overwritten...,), actives, kwtup end -function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) +function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B) width = get_width(gutils) mode = get_mode(gutils) @@ -246,10 +246,14 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, activep = API.DFT_DUP_NONEED end + if activep == API.DFT_CONSTANT RT = Const{RealRt} - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(RealRt, world) ) + elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(RealRt, (), world, #=justActive=#Val(true)) == ActiveState) + if active_reg_inner(RealRt, (), world, #=justActive=#Val(false)) == MixedState && B !== nothing + emit_error(B, orig, "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. 
See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information") + end RT = Active{RealRt} elseif activep == API.DFT_DUP_ARG @@ -298,7 +302,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) # 2) Create activity, and annotate function spec args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt) + RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -511,7 +515,7 @@ end # 2) Create activity, and annotate function spec args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt) + RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) needsShadowJL = if RT <: Active false diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index e5ce78aa02..4aaa1c813d 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -311,6 +311,44 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? 
Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) + ending = if Width == 1 + quote + if active_reg_nothrow(resT, Val(nothing)) == MixedState && !(initShadow isa Base.RefValue) + shadow_return = Ref(initShadow) + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, shadow_return, tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow, tape)) + end + end + else + expr = :() + shads = Expr[] + for i in 1:Width + if i == 1 + expr = quote !(initShadow[$i] isa Base.RefValue) end + else + expr = quote $expr || !(initShadow[$i] isa Base.RefValue) end + end + push!(shads, quote + Ref(initShadow[$i]) + end) + end + quote + if active_reg_nothrow(resT, Val(nothing)) == MixedState && ($expr) + shadow_return = ($(shads...),) + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, shadow_return..., tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow..., tape)) + end + end + end + return quote $(active_refs...) 
args = ($(wrapped...),) @@ -384,13 +422,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, initShadow, tape)) - else - return ReturnType((origRet, initShadow..., tape)) - end + $ending end end @@ -411,6 +443,31 @@ end return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) end +function nonzero_active_data(x::T) where T<: AbstractFloat + return x != zero(T) +end + +nonzero_active_data(::T) where T<: Base.RefValue = false +nonzero_active_data(::T) where T<: Array = false +nonzero_active_data(::T) where T<: Ptr = false + +function nonzero_active_data(x::T) where T + if guaranteed_const(T) + return false + end + if ismutable(x) + return false + end + + for f in fieldnames(T) + xi = getfield(x, f) + if nonzero_active_data(xi) + return true + end + end + return false +end + function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs) outs = [] for i in 1:N @@ -462,6 +519,10 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act false end + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + if any_mixed ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) @@ -479,16 +540,20 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act Const{typeof(runtime_mixed_call)}, annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? 
Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing + if !(annotation0M <: Active) && nonzero_active_data(($shadowret,)) + ET = ($(ElTypes...),) + throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) + end + end + if annotation0M <: Active adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape) else adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape) end nothing else - tt = Tuple{$(ElTypes...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if $Width != 1 && annotation0 <: Duplicated BatchDuplicated{rt, $Width} @@ -502,7 +567,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - tup = if tape.shadow_return !== nothing + if tape.shadow_return !== nothing + if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) + ET = ($(ElTypes...),) + throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) + end + end + tup = if annotation0 <: Active adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] else adjoint(dupClosure0 ? 
Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] From 6c2b0d9926ea50f1e4a7215b80278e6d3b75a0bf Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 10 Jun 2024 16:12:08 -0700 Subject: [PATCH 113/495] Improve rule arg mixed errors (#1530) * Improve rule arg mixed errors * fixup * improve errs --- src/rules/customrules.jl | 16 +++++++++++++++- src/rules/jitrules.jl | 12 ++++++++++-- src/rules/typeunstablerules.jl | 12 ++++++------ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 989a733c01..c658850c2e 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -122,11 +122,14 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, push!(activity, Ty) - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(arg.typ, world) ) + elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world, #=justActive=#Val(true)) == ActiveState) Ty = Active{arg.typ} llty = convert(LLVMType, Ty) arty = convert(LLVMType, arg.typ; allow_boxed=true) if B !== nothing + if active_reg_inner(arg.typ, (), world, #=justActive=#Val(false)) == MixedState + emit_error(B, orig, "Enzyme: Argument type $(arg.typ) has mixed internal activity types in evaluation of custom rule for $mi. 
See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information") + end al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) @@ -716,6 +719,17 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)]) if innerTy != value_type(tape) + if isabstracttype(TapeT) + msg = sprint() do io + println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") + println(io, "tape_idx=", tape_idx) + println(io, "sret=", sret) + println(io, "RT=", RT) + println(io, "tape=", tape) + println(io, "llvmf=", string(llvmf)) + end + throw(AssertionError(msg)) + end llty = convert(LLVMType, TapeT; allow_boxed=true) al0 = al = emit_allocobj!(B, TapeT) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 4aaa1c813d..f2f9d27407 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1111,7 +1111,7 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true) +function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1132,7 +1132,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - if firstconst + if firstconst && 
!firstconst_after_tape val = new_from_original(gutils, operands(orig)[start]) if lookup val = lookup_value(gutils, val, B) @@ -1196,6 +1196,14 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else pushfirst!(vals, unsafe_to_llvm(Val(ReturnType))) end + + if firstconst && firstconst_after_tape + val = new_from_original(gutils, operands(orig)[start]) + if lookup + val = lookup_value(gutils, val, B) + end + pushfirst!(vals, val) + end if mode != API.DEM_ForwardMode uncacheable = get_uncacheable(gutils, orig) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 101796401f..36f2798c0c 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -181,19 +181,19 @@ function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primarg end function func_runtime_newstruct_augfwd(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) quote - function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} + function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} $body end end end -@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, 
NewType} +@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType} N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) end @@ -325,7 +325,7 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap width = get_width(gutils) - sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false) + sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? 
Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true) if width == 1 shadow = sret @@ -369,7 +369,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape) + generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true) end return nothing From bd609070880dc97c1bfcfac0ba1d281b8c46ff45 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 11 Jun 2024 14:22:18 -0700 Subject: [PATCH 114/495] Fix reverse mode closure issues (#1533) * Fix custom reverse on closure * fix closure --- src/rules/customrules.jl | 33 ++++++++++++++++++++++++--------- test/rrules.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index c658850c2e..de24d01053 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -687,7 +687,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, end end end - push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) + + # push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) needsTape = !isghostty(TapeT) && !Core.Compiler.isconstType(TapeT) @@ -711,22 +712,37 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf)))) - _, sret, returnRoots = get_return_info(enzyme_custom_extract_mi(llvmf)[2]) + miRT = enzyme_custom_extract_mi(llvmf)[2] + _, sret, returnRoots = 
get_return_info(miRT) if !forward + funcTy = rev_TT.parameters[isKWCall ? 4 : 2] if needsTape @assert tape != C_NULL - tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) - innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)]) + tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) + !isghostty(funcTy) + trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself+(RT <: Active) + innerTy = value_type(parameters(llvmf)[trueidx]) if innerTy != value_type(tape) - if isabstracttype(TapeT) + if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL msg = sprint() do io println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") println(io, "tape_idx=", tape_idx) + println(io, "true_idx=", trueidx) + println(io, "isKWCall=", isKWCall) + println(io, "kwtup=", kwtup) + println(io, "funcTy=", funcTy) + println(io, "isghostty(funcTy)=", isghostty(funcTy)) + println(io, "miRT=", miRT) println(io, "sret=", sret) + println(io, "returnRoots=", returnRoots) + println(io, "swiftself=", swiftself) println(io, "RT=", RT) println(io, "tape=", tape) - println(io, "llvmf=", string(llvmf)) + println(io, "llvmf=", string(LLVM.function_type(llvmf))) + println(io, "TapeT=", TapeT) + println(io, "mi=", mi) + println(io, "ami=", ami) + println(io, "rev_TT =", rev_TT) end throw(AssertionError(msg)) end @@ -749,7 +765,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B)) else llety = convert(LLVMType, eltype(RT)) - ptr_val = invert_pointer(gutils, operands(orig)[1], B) + ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B) val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety))) for idx in 1:width ev = (width == 1) ? 
ptr_val : extract_value!(B, ptr_val, idx-1) @@ -769,8 +785,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if any_jltypes(llty) emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - - insert!(args, 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al) + insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al) end end diff --git a/test/rrules.jl b/test/rrules.jl index 1322895924..171c160b0f 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -305,4 +305,44 @@ end @test dU[1] ≈ 7 * ( 3.0 + 4.0im ) end end + + +struct Closure + v::Vector{Float64} +end + +function (cl::Closure)(x) + val = cl.v[1] * x + cl.v[1] = 0.0 + return val +end + + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure}, + ::Type{<:Active}, args::Vararg{Active,N}) where {N} + vec = copy(func.val.v) + pval = func.val(args[1].val) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure}, + dret::Active, tape, args::Vararg{Active,N}) where {N} + dargs = ntuple(Val(N)) do i + 7 * args[1].val * dret.val + tape[1] * 1000 + end + return dargs +end + +@testset "Closure rule" begin + cl = Closure([3.14]) + res = autodiff(Reverse, cl, Active, Active(2.7))[1][1] + @test res ≈ 7 * 2.7 + 3.14 * 1000 + @test cl.v[1] ≈ 0.0 +end + end # ReverseRules From fb6f959d35d947378a048ed3201105f1769f1dd4 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 12 Jun 2024 19:25:26 -0700 Subject: [PATCH 115/495] Support Julia 1.11 (#1372) * Test against v1.11 * WIP: adapt to 1.11 changes * fix constructor * Update interpreter.jl * add cache_token * fixup! add cache_token * Apply suggestions from code review * fixup! 
Apply suggestions from code review --------- Co-authored-by: William Moses --- .github/workflows/CI.yml | 7 +++++ src/compiler.jl | 19 ++++++++++++ src/compiler/interpreter.jl | 60 ++++++++++++++++++++++++------------- 3 files changed, 66 insertions(+), 20 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index fb719bc0ab..bc20420585 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,6 +26,7 @@ jobs: - '1.8' - '1.9' - '1.10' + - ~1.11.0-0 - 'nightly' os: - ubuntu-20.04 @@ -86,6 +87,11 @@ jobs: libEnzyme: packaged version: '1.10' assertions: true + - os: ubuntu-20.04 + arch: x64 + libEnzyme: packaged + version: '1.11' + assertions: true steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 @@ -170,6 +176,7 @@ jobs: - '1.8' - '1.9' - '1.10' + - ~1.11.0-0 - 'nightly' os: - ubuntu-latest diff --git a/src/compiler.jl b/src/compiler.jl index ed44563ec4..d4103c706e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2992,8 +2992,27 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme" # provide a specific interpreter to use. 
+if VERSION >= v"1.11.0-DEV.1552" +struct EnzymeCacheToken + target_type::Type + always_inline + method_table::Core.MethodTable + param_type::Type + mode::API.CDerivativeMode +end + +GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + EnzymeCacheToken( + typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job), + typeof(job.config.params), job.config.params.mode, + ) + +GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) +else GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) +end include("compiler/passes.jl") include("compiler/optimize.jl") diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 5885679be5..95ff12a422 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,12 +1,28 @@ module Interpreter import Enzyme: API using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance -using GPUCompiler: CodeCache, WorldView, @safe_debug +using GPUCompiler: @safe_debug +if VERSION < v"1.11.0-DEV.1552" + using GPUCompiler: CodeCache, WorldView, @safe_debug +end +const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552" + import ..Enzyme import ..EnzymeRules +@static if VERSION ≥ v"1.11.0-DEV.1498" + import Core.Compiler: get_inference_world + using Base: get_world_counter +else + import Core.Compiler: get_world_counter, get_world_counter as get_inference_world +end + struct EnzymeInterpreter <: AbstractInterpreter - global_cache::CodeCache +@static if HAS_INTEGRATED_CACHE + token::Any +else + code_cache::CodeCache +end method_table::Union{Nothing,Core.MethodTable} # Cache of 
inference results for this particular interpreter @@ -19,34 +35,38 @@ struct EnzymeInterpreter <: AbstractInterpreter opt_params::OptimizationParams mode::API.CDerivativeMode +end - function EnzymeInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) - @assert world <= Base.get_world_counter() +function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) + @assert world <= Base.get_world_counter() - return new( - cache, - mt, + return EnzymeInterpreter( + cache_or_token, + mt, - # Initially empty cache - Vector{InferenceResult}(), + # Initially empty cache + Vector{InferenceResult}(), - # world age counter - world, + # world age counter + world, - # parameters for inference and optimization - InferenceParams(unoptimize_throw_blocks=false), - VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() : - OptimizationParams(unoptimize_throw_blocks=false), - mode - ) - end + # parameters for inference and optimization + InferenceParams(unoptimize_throw_blocks=false), + VERSION >= v"1.8.0-DEV.486" ? 
OptimizationParams() : + OptimizationParams(unoptimize_throw_blocks=false), + mode + ) end Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params -Core.Compiler.get_world_counter(interp::EnzymeInterpreter) = interp.world +get_inference_world(interp::EnzymeInterpreter) = interp.world Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cache -Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.global_cache, interp.world) +@static if HAS_INTEGRATED_CACHE + Core.Compiler.cache_owner(interp::EnzymeInterpreter) = interp.token +else + Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.code_cache, interp.world) +end # No need to do any locking since we're not putting our results into the runtime cache Core.Compiler.lock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing From 15f9bb1018541ad59f58019f7657044255864d89 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 13 Jun 2024 07:08:19 -0700 Subject: [PATCH 116/495] MixedDuplicated for custom rules (#1534) * MixedDuplicated for custom rules * more mixed duplicated * Handle mixed custom rule arg * starting batching * fix * fix tests * simplify mixed activity use --- lib/EnzymeCore/src/EnzymeCore.jl | 26 ++++ src/Enzyme.jl | 3 + src/compiler.jl | 177 +++++++++++++++++++++++++-- src/rules/customrules.jl | 102 ++++++++++++---- src/rules/jitrules.jl | 204 +++++-------------------------- test/abi.jl | 1 + test/mixedrrule.jl | 108 ++++++++++++++++ test/rrules.jl | 1 + test/usermixed.jl | 91 ++++++++++++++ 9 files changed, 508 insertions(+), 205 deletions(-) create mode 100644 test/mixedrrule.jl create mode 100644 test/usermixed.jl diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index fb788fd5a6..fee15c9dc6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -150,6 +150,32 @@ 
end @inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N +""" + MixedDuplicated(x, ∂f_∂x) + +Like [`Duplicated`](@ref), except x may contain both active [immutable] and duplicated [mutable] +data which is differentiable. Only used within custom rules. +""" +struct MixedDuplicated{T} <: Annotation{T} + val::T + dval::Base.RefValue{T} + @inline MixedDuplicated(x::T1, dx::Base.RefValue{T1}, check::Bool=true) where {T1} = new{T1}(x, dx) +end + +""" + BatchMixedDuplicated(x, ∂f_∂xs) + +Like [`MixedDuplicated`](@ref), except contains several shadows to compute derivatives +for all at once. Only used within custom rules. +""" +struct BatchMixedDuplicated{T,N} <: Annotation{T} + val::T + dval::NTuple{N,Base.RefValue{T}} + @inline BatchMixedDuplicated(x::T1, dx::NTuple{N,Base.RefValue{T1}}, check::Bool=true) where {T1, N} = new{T1, N}(x, dx) +end +@inline batch_size(::BatchMixedDuplicated{T,N}) where {T,N} = N +@inline batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N} = N + """ abstract type ABI diff --git a/src/Enzyme.jl b/src/Enzyme.jl index a6bc604e6a..87b8e249e9 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -11,6 +11,9 @@ export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc +import EnzymeCore: MixedDuplicated, BatchMixedDuplicated +export MixedDuplicated, BatchMixedDuplicated + import EnzymeCore: batch_size, get_func export batch_size, get_func diff --git a/src/compiler.jl b/src/compiler.jl index d4103c706e..bdaacd05dd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2450,6 +2450,50 @@ else end end +function store_nonjl_types!(B, startval, p) + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + vals = LLVM.Value[] + if p != nothing + push!(vals, p) + end + todo = Tuple{Tuple, LLVM.Value}[((), startval)] + while length(todo) != 0 + path, cur = popfirst!(todo) + ty = value_type(cur) + if isa(ty, 
LLVM.PointerType) + if any_jltypes(ty) + continue + end + end + if isa(ty, LLVM.ArrayType) + if any_jltypes(ty) + for i=1:length(ty) + ev = extract_value!(B, cur, i-1) + push!(todo, ((path..., i-1), ev)) + end + continue + end + end + if isa(ty, LLVM.StructType) + if any_jltypes(ty) + for (i, t) in enumerate(LLVM.elements(ty)) + ev = extract_value!(B, cur, i-1) + push!(todo, ((path..., i-1), ev)) + end + continue + end + end + parray = LLVM.Value[LLVM.ConstantInt(LLVM.IntType(64), 0)] + for v in path + push!(parray, LLVM.ConstantInt(LLVM.IntType(32), v)) + end + gptr = gep!(B, value_type(startval), p, parray) + st = store!(B, cur, gptr) + end + return +end + function get_julia_inner_types(B, p, startvals...; added=[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -3404,7 +3448,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else push!(args_activity, API.DFT_OUT_DIFF) end - elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc + elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated push!(args_activity, API.DFT_DUP_ARG) elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed push!(args_activity, API.DFT_DUP_NONEED) @@ -3588,7 +3632,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, isboxed = GPUCompiler.deserves_argbox(source_typ) llvmT = isboxed ? 
T_prjlvalue : convert(LLVMType, source_typ) - push!(T_wrapperargs, llvmT) if T <: Const || T <: BatchDuplicatedFunc @@ -3617,6 +3660,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if is_adjoint && i != 1 push!(ActiveRetTypes, Nothing) end + elseif T <: MixedDuplicated || T <: BatchMixedDuplicated + push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) + if is_adjoint && i != 1 + push!(ActiveRetTypes, Nothing) + end else error("calling convention should be annotated, got $T") end @@ -3799,7 +3847,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if isghostty(T′) || Core.Compiler.isconstType(T′) continue end - push!(realparms, params[i]) + + isboxed = GPUCompiler.deserves_argbox(T′) + + llty = value_type(params[i]) + + convty = convert(LLVMType, T′; allow_boxed=true) + + if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) + al = emit_allocobj!(builder, Base.RefValue{T′}) + al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + store!(builder, params[i], al) + al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) + push!(realparms, al) + else + push!(realparms, params[i]) + end + i += 1 if T <: Const elseif T <: Active @@ -3827,6 +3891,34 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, elseif T <: Duplicated || T <: DuplicatedNoNeed push!(realparms, params[i]) i += 1 + elseif T <: MixedDuplicated || T <: BatchMixedDuplicated + parmsi = params[i] + + if T <: BatchMixedDuplicated + if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}}) + njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue) + parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi)))) + parmsi = load!(builder, njlvalue, parmsi) + end + end + + isboxed = GPUCompiler.deserves_argbox(T′) + + resty = isboxed ? 
llty : LLVM.PointerType(llty, Derived) + + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty))) + for idx in 1:width + pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx-1) + pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) + pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived)) + if isboxed + pv = load!(builder, llty, pv, "mixedboxload") + end + ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1) + end + + push!(realparms, ival) + i += 1 elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′}) val = params[i] @@ -4357,6 +4449,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # generate the wrapper function type & definition wrapper_types = LLVM.LLVMType[] + wrapper_attrs = Vector{LLVM.Attribute}[] _, sret, returnRoots = get_return_info(actualRetType) sret_union = is_sret_union(actualRetType) @@ -4391,31 +4484,44 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if swiftself push!(wrapper_types, value_type(parameters(entry_f)[1+sret+returnRoots])) + push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("swiftself")]) end boxedArgs = Set{Int}() loweredArgs = Set{Int}() + raisedArgs = Set{Int}() for arg in args typ = arg.codegen.typ if GPUCompiler.deserves_argbox(arg.typ) push!(boxedArgs, arg.arg_i) push!(wrapper_types, typ) + push!(wrapper_attrs, LLVM.Attribute[]) elseif arg.cc != GPUCompiler.BITS_REF - push!(wrapper_types, typ) + if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) + push!(boxedArgs, arg.arg_i) + push!(raisedArgs, arg.arg_i) + push!(wrapper_types, LLVM.PointerType(typ, Derived)) + push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) + else + push!(wrapper_types, typ) + push!(wrapper_attrs, LLVM.Attribute[]) + end else # bits ref, and not boxed - # if 
TT.parameters[arg.arg_i] <: Const - # push!(boxedArgs, arg.arg_i) - # push!(wrapper_types, typ) - # else + if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) + push!(boxedArgs, arg.arg_i) + push!(wrapper_types, typ) + push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) + else push!(wrapper_types, eltype(typ)) + push!(wrapper_attrs, LLVM.Attribute[]) push!(loweredArgs, arg.arg_i) - # end + end end end - if length(loweredArgs) == 0 && !sret && !sret_union + if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union return entry_f, returnRoots, boxedArgs, loweredArgs end @@ -4436,8 +4542,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end push!(function_attributes(wrapper_f), EnumAttribute("returns_twice")) push!(function_attributes(entry_f), EnumAttribute("returns_twice")) - if swiftself - push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself")) + for (i, v) in enumerate(wrapper_attrs) + for attr in v + push!(parameter_attributes(wrapper_f, i), attr) + end end seen = TypeTreeTable() @@ -4463,6 +4571,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function parm = ops[arg.codegen.i] if arg.arg_i in loweredArgs push!(nops, load!(builder, convert(LLVMType, arg.typ), parm)) + elseif arg.arg_i in raisedArgs + obj = emit_allocobj!(builder, arg.typ) + bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj)))) + store!(builder, parm, bc) + addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived)) + push!(nops, addr) else push!(nops, parm) end @@ -4547,6 +4661,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen)))) push!(parameter_attributes(wrapper_f, 
arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + elseif arg.arg_i in raisedArgs + wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm) + ctx = LLVM.context(wrapparm) + push!(wrapper_args, wrapparm) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen)))) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) else push!(wrapper_args, wrapparm) for attr in collect(parameter_attributes(entry_f, arg.codegen.i)) @@ -4626,6 +4747,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function elseif LLVM.return_type(entry_ft) == LLVM.VoidType() ret!(builder) else + ctx = LLVM.context(wrapper_f) push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) @@ -4687,7 +4809,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 msg = sprint() do io println(io, string(mod)) - println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) println(io, 
string(wrapper_f)) println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs) println(io, "Broken function") @@ -5966,6 +6088,35 @@ end push!(ActiveRetTypes, Nothing) end push!(ccexprs, argexpr) + elseif T <: MixedDuplicated + if RawCall + argexpr = argexprs[i] + i+=1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + push!(types, Any) + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + elseif T <: BatchMixedDuplicated + if RawCall + argexpr = argexprs[i] + i+=1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + isboxedvec = GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{source_typ}}) + if isboxedvec + push!(types, Any) + else + push!(types, NTuple{width, Base.RefValue{source_typ}}) + end + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) else error("calling convention should be annotated, got $T") end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index de24d01053..1eab07c3f2 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -12,6 +12,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, actives = LLVM.Value[] + mixeds = Tuple{LLVM.Value, Type, LLVM.Value}[] uncacheable = get_uncacheable(gutils, orig) mode = get_mode(gutils) @@ -122,14 +123,11 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, push!(activity, Ty) - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world, #=justActive=#Val(true)) == ActiveState) + elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world) == ActiveState) Ty = Active{arg.typ} llty = convert(LLVMType, Ty) arty = convert(LLVMType, arg.typ; allow_boxed=true) if B !== nothing - if active_reg_inner(arg.typ, (), world, #=justActive=#Val(false)) == MixedState - emit_error(B, orig, "Enzyme: Argument type $(arg.typ) has 
mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information") - end al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) @@ -157,44 +155,92 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, ival = lookup_value(gutils, ival, B) end end + shadowty = arg.typ + mixed = false if width == 1 - if activep == API.DFT_DUP_ARG - Ty = Duplicated{arg.typ} + + if active_reg_inner(arg.typ, (), world) == MixedState + # TODO mixedupnoneed + shadowty = Base.RefValue{shadowty} + Ty = MixedDuplicated{arg.typ} + mixed = true else - @assert activep == API.DFT_DUP_NONEED - Ty = DuplicatedNoNeed{arg.typ} + if activep == API.DFT_DUP_ARG + Ty = Duplicated{arg.typ} + else + @assert activep == API.DFT_DUP_NONEED + Ty = DuplicatedNoNeed{arg.typ} + end end else - if activep == API.DFT_DUP_ARG - Ty = BatchDuplicated{arg.typ, Int(width)} + if active_reg_inner(arg.typ, (), world) == MixedState + # TODO batchmixedupnoneed + shadowty = Base.RefValue{shadowty} + Ty = BatchMixedDuplicated{arg.typ, Int(width)} + mixed = true else - @assert activep == API.DFT_DUP_NONEED - Ty = BatchDuplicatedNoNeed{arg.typ, Int(width)} + if activep == API.DFT_DUP_ARG + Ty = BatchDuplicated{arg.typ, Int(width)} + else + @assert activep == API.DFT_DUP_NONEED + Ty = BatchDuplicatedNoNeed{arg.typ, Int(width)} + end end end llty = convert(LLVMType, Ty) arty = convert(LLVMType, arg.typ; allow_boxed=true) + iarty = convert(LLVMType, shadowty; allow_boxed=true) sarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, arty)) + siarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, iarty)) if B !== nothing al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) ptr = inbounds_gep!(B, llty, al, 
[LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + needsload = false if value_type(val) != eltype(value_type(ptr)) val = load!(B, arty, val) + if !mixed + ptr_val = ival + ival = UndefValue(siarty) + for idx in 1:width + ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + ld = load!(B, iarty, ev) + ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1) + end + end + needsload = true + end + store!(B, val, ptr) + + iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) + + if mixed + RefTy = arg.typ + if width != 1 + RefTy = NTuple{N, RefTy} + end + llrty = convert(LLVMType, RefTy) + RefTy = Base.RefValue{RefTy} + refal0 = refal = emit_allocobj!(B, RefTy) + refal = bitcast!(B, refal, LLVM.PointerType(llrty, addrspace(value_type(refal)))) + + @assert needsload ptr_val = ival - ival = UndefValue(sarty) + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llrty))) for idx in 1:width ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) - ld = load!(B, arty, ev) + ld = load!(B, llrty, ev) ival = (width == 1 ) ? 
ld : insert_value!(B, ival, ld, idx-1) end + store!(B, ival, refal) + emit_writebarrier!(B, get_julia_inner_types(B, refal0, ival)) + ival = refal0 + push!(mixeds, (ptr_val, arg.typ, refal)) end - store!(B, val, ptr) - iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) store!(B, ival, iptr) if any_jltypes(llty) @@ -207,7 +253,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, end end - return args, activity, (overwritten...,), actives, kwtup + return args, activity, (overwritten...,), actives, kwtup, mixeds end function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B) @@ -304,7 +350,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) end # 2) Create activity, and annotate function spec - args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) + args, activity, overwritten, actives, kwtup, _ = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) alloctx = LLVM.IRBuilder() @@ -517,7 +563,7 @@ end isKWCall = isKWCallSignature(mi.specTypes) # 2) Create activity, and annotate function spec - args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) + args, activity, overwritten, actives, kwtup, mixeds = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) needsShadowJL = if RT <: Active @@ -573,7 +619,7 @@ end end end end - return ami, augprimal_TT, (args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal) + return ami, augprimal_TT, (args, activity, overwritten, actives, 
kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds) end @inline function has_aug_fwd_rule(orig, gutils) @@ -599,7 +645,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, # 2) Create activity, and annotate function spec ami, augprimal_TT, setup = aug_fwd_mi(orig, gutils, forward, B) - args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal = setup + args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds = setup needsShadowJL = if RT <: Active false @@ -970,6 +1016,20 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, end idx+=1 end + + for (ptr_val, argTyp, refal) in mixeds + RefTy = argTyp + if width != 1 + RefTy = NTuple{N, RefTy} + end + curs = load!(B, convert(LLVMType, RefTy), refal) + + for idx in 1:width + evp = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + evcur = (width == 1) ? curs : extract_value!(B, curs, idx-1) + store_nonjl_types!(B, evcur, evp) + end + end end if forward diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f2f9d27407..af8f83b80e 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,86 +1,3 @@ -function func_mixed_call(N) - allargs = Expr[] - typeargs = Union{Symbol,Expr}[] - exprs2 = Union{Symbol,Expr}[] - for i in 1:N - arg = Symbol("arg_$i") - targ = Symbol("T$i") - e = :($arg::$targ) - push!(allargs, e) - push!(typeargs, targ) - - inarg = quote - if RefTypes[1+$i] - $arg[] - else - $arg - end - end - push!(exprs2, inarg) - end - - quote - @generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)} - fexpr = :f - if RefTypes[1] - fexpr = :(($fexpr)[]) - end - exprs2 = Union{Symbol,Expr}[] - for i in 1:$N - arg = Symbol("arg_$i") - inarg = if RefTypes[1+i] - :($arg[]) - else - :($arg) - end - push!(exprs2, inarg) - end - @static if VERSION ≥ v"1.8-" - return quote - Base.@_inline_meta - 
@inline $fexpr($(exprs2...)) - end - else - return quote - Base.@_inline_meta - $fexpr($(exprs2...)) - end - end - end - end -end - -@generated function runtime_mixed_call(::Val{RefTypes}, f::F, allargs::Vararg{Any, N}) where {RefTypes, F, N} - fexpr = :f - if RefTypes[1] - fexpr = :(($fexpr)[]) - end - exprs2 = Union{Symbol,Expr}[] - for i in 1:N - inarg = if RefTypes[1+i] - :(allargs[$i][]) - else - :(allargs[$i]) - end - push!(exprs2, inarg) - end - @static if VERSION ≥ v"1.8-" - return quote - Base.@_inline_meta - @inline $fexpr($(exprs2...)) - end - else - return quote - Base.@_inline_meta - $fexpr($(exprs2...)) - end - end -end - -for N in 0:10 - eval(func_mixed_call(N)) -end - function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] @@ -192,7 +109,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, if $aref == ActiveState Active($(primargs[i])) elseif $aref == MixedState - $((Width == 1) ? :Duplicated : :BatchDuplicated)(Ref($(primargs[i])), $(shadowargs[i])) + $((Width == 1) ? :MixedDuplicated : :BatchMixedDuplicated)($(primargs[i]), $(shadowargs[i])) else $((Width == 1) ? 
:Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) end @@ -361,45 +278,23 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) false end + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - internal_tape, origRet, initShadow, annotation = if any_mixed - ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} - rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) - annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - - annotationM = if $Width != 1 && annotation0M <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0M - end - worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) - ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) - - forward, adjoint = thunk(Val(worldM), - Const{typeof(runtime_mixed_call)}, - annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - - forward(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationM - + annotationA = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} else - tt = Tuple{$(ElTypes...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - - annotationA = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0 - end - world = codegen_world_age(FT, tt) + annotation0 + end + world = codegen_world_age(FT, tt) - forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, - annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward, adjoint = thunk(Val(world), dupClosure0 ? 
Duplicated{FT} : Const{FT}, + annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationA - end + internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...) + annotation = annotationA resT = typeof(origRet) if annotation <: Const @@ -523,64 +418,31 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - if any_mixed - ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} - rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) - annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - - annotationM = if $Width != 1 && annotation0M <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0M - end - worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) - ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) - - _, adjoint = thunk(Val(worldM), - Const{typeof(runtime_mixed_call)}, - annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - - if tape.shadow_return !== nothing - if !(annotation0M <: Active) && nonzero_active_data(($shadowret,)) - ET = ($(ElTypes...),) - throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) - end - end - if annotation0M <: Active - adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? 
Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape) - else - adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape) - end - nothing + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} else + annotation0 + end - annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0 - end - - world = codegen_world_age(FT, tt) + world = codegen_world_age(FT, tt) - _, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, - annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + _, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - if tape.shadow_return !== nothing - if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) - ET = ($(ElTypes...),) - throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) - end - end - tup = if annotation0 <: Active - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] - else - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + if tape.shadow_return !== nothing + if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) + ET = ($(ElTypes...),) + throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) end - - $(outs...) 
end + tup = if annotation0 <: Active + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + else + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + end + + $(outs...) return nothing end end diff --git a/test/abi.jl b/test/abi.jl index 7371af504e..8d4251bb70 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -442,3 +442,4 @@ abssum(x) = sum(abs2, x); end +include("usermixed.jl") \ No newline at end of file diff --git a/test/mixedrrule.jl b/test/mixedrrule.jl new file mode 100644 index 0000000000..32407f3c12 --- /dev/null +++ b/test/mixedrrule.jl @@ -0,0 +1,108 @@ +module ReverseMixedRules + +using Enzyme +using Enzyme: EnzymeRules +using Test + +import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig +using .EnzymeRules + +function mixfnc(tup) + return tup[1] * tup[2][1] +end + +function mixouter(x, y) + res = mixfnc((x, y)) + fill!(y, 0.0) + return res +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(mixfnc)}, + ::Type{<:Active}, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) + pval = func.val(tup.val) + vec = copy(tup.val[2]) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(mixfnc)}, + dret::Active, tape, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) + prev = tup.dval[] + tup.dval[] = (7 * tape[1] * dret.val, prev[2]) + prev[2][1] = 1000 * dret.val * tup.val[1] + return (nothing,) +end + +@testset "Mixed activity rule" begin + x = [3.14] + dx = [0.0] + res = autodiff(Reverse, mixouter, Active, Active(2.7), Duplicated(x, dx))[1][1] + @test res ≈ 7 * 3.14 + @test dx[1] ≈ 1000 * 2.7 + @test x[1] ≈ 0.0 +end + + +function recmixfnc(tup) + return sum(tup[1]) * tup[2][1] +end + +function recmixouter(x, y, z) + res = recmixfnc(((x, z), y)) + fill!(y, 0.0) + return 
res +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, + ::Type{<:Active}, tup) + pval = func.val(tup.val) + vec = copy(tup.val[2]) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +# check if a value is guaranteed to be not contain active[register] data +# (aka not either mixed or active) +@inline function guaranteed_nonactive(::Type{T}) where T + rt = Enzyme.Compiler.active_reg_inner(T, (), nothing) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, + dret::Active, tape, tup) + prev = tup.dval[] + dRT = typeof(prev) + + tup.dval[] = Enzyme.Compiler.splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i + Base.@_inline_meta + pv = getfield(prev, i) + if i == 1 + next = (7 * tape[1] * dret.val, 31 * tape[1] * dret.val) + Enzyme.Compiler.recursive_add(pv, next, identity, guaranteed_nonactive) + else + pv + end + end) + prev[2][1] = 1000 * dret.val * tup.val[1][1] + .0001 * dret.val * tup.val[1][2] + return (nothing,) +end + +@testset "Recursive Mixed activity rule" begin + x = [3.14] + dx = [0.0] + res = autodiff(Reverse, recmixouter, Active, Active(2.7), Duplicated(x, dx), Active(56.47))[1] + @test res[1] ≈ 7 * 3.14 + @test res[3] ≈ 31 * 3.14 + @test dx[1] ≈ 1000 * 2.7 + .0001 * 56.47 + @test x[1] ≈ 0.0 +end + +end # ReverseMixedRules diff --git a/test/rrules.jl b/test/rrules.jl index 171c160b0f..ee3b9af138 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -345,4 +345,5 @@ end @test cl.v[1] ≈ 0.0 end +include("mixedrrule.jl") end # ReverseRules diff --git a/test/usermixed.jl b/test/usermixed.jl new file mode 100644 index 0000000000..b5cd0e158b --- /dev/null +++ b/test/usermixed.jl @@ -0,0 +1,91 @@ +using Enzyme +using Test + +function user_mixfnc(tup) + return tup[1] * tup[2][1] +end + +@testset "MixedDuplicated struct call" 
begin + tup = (2.7, [3.14]) + dtup = Ref((0.0, [0.0])) + + res = autodiff(Reverse, user_mixfnc, Active, MixedDuplicated(tup, dtup)) + @test dtup[][1] ≈ 3.14 + @test dtup[][2] ≈ [2.7] +end + + +function user_mixfnc_byref(out, tup) + out[] = tup[1] * tup[2][1] + return nothing +end + +@testset "Batch MixedDuplicated struct call" begin + tup = (2.7, [3.14]) + dtup = (Ref((0.0, [0.0])), Ref((0.0, [0.0]))) + out = Ref(0.0) + dout = (Ref(1.0), Ref(3.0)) + res = autodiff(Reverse, user_mixfnc_byref, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup)) + @test dtup[1][][1] ≈ 3.14 + @test dtup[1][][2] ≈ [2.7] + @test dtup[2][][1] ≈ 3*3.14 + @test dtup[2][][2] ≈ [3*2.7] +end + +function mix_square(x) + return x * x +end + +@testset "MixedDuplicated float64 call" begin + tup = 2.7 + dtup = Ref(0.0) + res = autodiff(Reverse, mix_square, Active, MixedDuplicated(tup, dtup)) + @test res[1] == (nothing,) + @test dtup[] ≈ 2 * 2.7 +end + + +function mix_square_byref(out, x) + out[] = x * x + return nothing +end + +@testset "BatchMixedDuplicated float64 call" begin + tup = 2.7 + dtup = (Ref(0.0), Ref(0.0)) + out = Ref(0.0) + dout = (Ref(1.0), Ref(3.0)) + res = autodiff(Reverse, mix_square_byref, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup)) + @test res[1] == (nothing,nothing) + @test dtup[1][] ≈ 2 * 2.7 + @test dtup[2][] ≈ 3 * 2 * 2.7 +end + +function mix_ar(x) + return x[1] * x[2] +end + +@testset "MixedDuplicated vector{float64} call" begin + tup = [2.7, 3.14] + dtup = Ref([0.0, 0.0]) + res = autodiff(Reverse, mix_ar, Active, MixedDuplicated(tup, dtup)) + @test res[1] == (nothing,) + @test dtup[] ≈ [3.14, 2.7] +end + + +function mix_ar_byref(out, x) + out[] = x[1] * x[2] + return nothing +end + +@testset "BatchMixedDuplicated vector{float64} call" begin + tup = [2.7, 3.14] + dtup = (Ref([0.0, 0.0]), Ref([0.0, 0.0])) + out = Ref(0.0) + dout = (Ref(1.0), Ref(3.0)) + res = autodiff(Reverse, mix_ar_byref, Const, BatchDuplicated(out, dout), 
BatchMixedDuplicated(tup, dtup)) + @test res[1] == (nothing,nothing) + @test dtup[1][] ≈ [3.14, 2.7] + @test dtup[2][] ≈ [3*3.14, 3*2.7] +end From 6b8a50c666e72515d77dfadfca7cbbcb1a411fca Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 13 Jun 2024 07:10:37 -0700 Subject: [PATCH 117/495] Update Project.toml --- lib/EnzymeCore/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 20a89b9a05..fe809f1a14 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.4" +version = "0.7.5" [compat] Adapt = "3, 4" From 835b6d5797e73eb7fa48b0985f441aa2258f8951 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 13 Jun 2024 07:10:52 -0700 Subject: [PATCH 118/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1292fbc5ce..66a06e1714 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.4" +EnzymeCore = "0.7.5" Enzyme_jll = "0.0.121" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" From a889bb620a91870622d882cdb2df2652e88aa5db Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 14 Jun 2024 12:57:38 -0400 Subject: [PATCH 119/495] Mixed activity for getfield (#1535) * Mixed activity for getfield * bump ver * fixup runtime iterate for mixed * fix iter * mixedduplicated return * fixup * fix * try inference fix re ref * try more * Update Project.toml * Update jitrules.jl --- Project.toml | 4 +- src/Enzyme.jl | 14 ++ src/compiler.jl | 137 +++++++++---- src/rules/jitrules.jl | 348 +++++++++++++++++++++------------ src/rules/typeunstablerules.jl | 60 ++++-- test/applyiter.jl | 2 + 
test/mixedapplyiter.jl | 144 ++++++++++++++ test/usermixed.jl | 116 +++++++++++ 8 files changed, 651 insertions(+), 174 deletions(-) create mode 100644 test/mixedapplyiter.jl diff --git a/Project.toml b/Project.toml index 66a06e1714..f72e22bd39 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.13" +version = "0.12.14" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.121" +Enzyme_jll = "0.0.122" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 87b8e249e9..de694b04f3 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -64,6 +64,10 @@ end arg = @inbounds args[i] if arg isa Active return true + elseif arg isa MixedDuplicated + return true + elseif arg isa BatchMixedDuplicated + return true else return false end @@ -95,6 +99,10 @@ end end @inline same_or_one_rec(current) = current +@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T, N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::Type{BatchMixedDuplicated{T, N}}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) @inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T, N}, args...) where {T,N} = same_or_one_rec(same_or_one_helper(current, N), args...) @inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T, N}}, args...) 
where {T,N} = @@ -844,6 +852,12 @@ result, ∂v, ∂A else BatchDuplicatedNoNeed{eltype(A2), width} end + elseif A2 <: MixedDuplicated && width != 1 + if A2 isa UnionAll + BatchMixedDuplicated{T, width} where T + else + BatchMixedDuplicated{eltype(A2), width} + end else A2 end diff --git a/src/compiler.jl b/src/compiler.jl index bdaacd05dd..7bbb1bbedd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -543,6 +543,13 @@ end return res end +# check if a value is guaranteed to be not contain active[register] data +# (aka not either mixed or active) +@inline function guaranteed_nonactive(::Type{T}) where T + rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + @inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} @@ -555,6 +562,8 @@ end else if ActReg == ActiveState return Active{T} + elseif ActReg == MixedState + return MixedDuplicated{T} else return Duplicated{T} end @@ -2494,7 +2503,7 @@ function store_nonjl_types!(B, startval, p) return end -function get_julia_inner_types(B, p, startvals...; added=[]) +function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2547,8 +2556,20 @@ function get_julia_inner_types(B, p, startvals...; added=[]) end continue end - GPUCompiler.@safe_warn "Enzyme illegal subtype", ty, cur, SI, p, v - @assert false + if isa(ty, LLVM.IntegerType) + continue + end + if isa(ty, LLVM.FloatingPointType) + continue + end + msg = sprint() do io + println(io, "Enzyme illegal subtype") + println(io, "ty=", ty) + println(io, "cur=", cur) + println(io, "p=", p) + println(io, "startvals=", startvals) + end + throw(AssertionError(msg)) end return vals end @@ -3474,7 +3495,11 @@ 
function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr # If requested, the shadow return value of the function # For each active (non duplicated) argument # The adjoint of that argument - retType = convert(API.CDIFFE_TYPE, rt) + retType = if rt <: MixedDuplicated || rt <: BatchMixedDuplicated + API.DFT_OUT_DIFF + else + convert(API.CDIFFE_TYPE, rt) + end rules = Dict{String, API.CustomRuleType}( "jl_array_copy" => @cfunction(inout_rule, @@ -3513,7 +3538,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED) + shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated) returnUsed &= returnPrimal augmented = API.EnzymeCreateAugmentedPrimal( logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed, @@ -3679,16 +3704,20 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end # API.DFT_OUT_DIFF - if is_adjoint && rettype <: Active - @assert !sret_union - if allocatedinline(actualRetType) != allocatedinline(literal_rt) - throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) - end - if !allocatedinline(actualRetType) - throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) + if is_adjoint + if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + @assert !sret_union + if allocatedinline(actualRetType) != allocatedinline(literal_rt) + throw(AssertionError("Base.allocatedinline(actualRetType) != 
Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) + end + if rettype <: Active + if !allocatedinline(actualRetType) + throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) + end + end + dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active)))) + push!(T_wrapperargs, dretTy) end - dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType))) - push!(T_wrapperargs, dretTy) end data = Array{Int64}(undef, 3) @@ -3730,6 +3759,12 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, else push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) end + elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if width == 1 + push!(sret_types, Base.RefValue{literal_rt}) + else + push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}})) + end end else @assert rettype <: Const || rettype <: Active @@ -3953,7 +3988,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - if is_adjoint && rettype <: Active + if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) push!(realparms, params[i]) i += 1 end @@ -3999,12 +4034,26 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if data[i] != -1 eval = extract_value!(builder, val, data[i]) end + if i == 3 + if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) + for idx in 1:width + pv = (width == 1) ? 
eval : extract_value!(builder, eval, idx-1) + al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}) + llty = value_type(pv) + al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + store!(builder, pv, al) + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv)) + ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1) + end + eval = ival + end + end eval = fixup_abi(i, eval) ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) returnNum+=1 - if i == 3 && shadow_init shadows = LLVM.Value[] if width == 1 @@ -5943,22 +5992,28 @@ end end if !RawCall && !(CC <: PrimalErrorThunk) - if rettype <: Active + if rettype <: Active if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - throw(MethodError($CC(fptr), $args)) + throw(MethodError($CC(fptr), (fn, args...))) + end + end + elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if length(argtypes) + is_adjoint * width + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) end end elseif rettype <: Const if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC(fptr), $args)) + throw(MethodError($CC(fptr), (fn, args...))) end end else if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC(fptr), $args)) + throw(MethodError($CC(fptr), (fn, args...))) end end end @@ -5966,11 +6021,6 @@ end types = DataType[] - if eltype(rettype) === Union{} && false - return quote - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. 
Giving up") - end - end if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) rrt = eltype(rettype) error("Return type `$rrt` not marked Const, but is ghost or const type.") @@ -6133,17 +6183,28 @@ end end # API.DFT_OUT_DIFF - if is_adjoint && rettype <: Active - # TODO handle batch width - @assert allocatedinline(jlRT) - j_drT = if width == 1 - jlRT - else - NTuple{width, jlRT} + if is_adjoint + if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + # TODO handle batch width + if rettype <: Active + @assert allocatedinline(jlRT) + end + j_drT = if width == 1 + jlRT + else + NTuple{width, jlRT} + end + push!(types, j_drT) + if width == 1 || rettype <: Active + push!(ccexprs, argexprs[i]) + i+=1 + else + push!(ccexprs, quote + ($(argexprs[i:i+width-1]...),) + end) + i+=width + end end - push!(types, j_drT) - push!(ccexprs, argexprs[i]) - i+=1 end if needs_tape @@ -6181,8 +6242,12 @@ end end if rettype <: Duplicated || rettype <: DuplicatedNoNeed push!(sret_types, jlRT) + elseif rettype <: MixedDuplicated + push!(sret_types, Base.RefValue{jlRT}) elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed push!(sret_types, AnonymousStruct(NTuple{width, jlRT})) + elseif rettype <: BatchMixedDuplicated + push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}})) elseif CC <: AugmentedForwardThunk push!(sret_types, Nothing) elseif rettype <: Const @@ -6406,6 +6471,8 @@ end @inline remove_innerty(::Type{<:DuplicatedNoNeed}) = DuplicatedNoNeed @inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated @inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed +@inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated +@inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated @inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, 
::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} JuliaContext() do ctx diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index af8f83b80e..f04145f7bc 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,4 +1,4 @@ -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false) +function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false, reverse=false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -76,23 +76,50 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)); end) expr = if iterate - :( - if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - @assert $(primtypes[i]) !== DataType - if !$forwardMode && active_reg($(primtypes[i])) - iterate_unwrap_augfwd_act($(primargs[i])...) - else - $((Width == 1) ? 
quote - iterate_unwrap_augfwd_dup(Val($forwardMode), $(primargs[i]), $(shadowargs[i])) - end : quote - iterate_unwrap_augfwd_batchdup(Val($forwardMode), Val($Width), $(primargs[i]), $(shadowargs[i])) + if forwardMode + dupexpr = if Width == 1 + quote + iterate_unwrap_fwd_dup($(primargs[i]), $(shadowargs[i])) + end + else + quote + iterate_unwrap_fwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + end + end + :( + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + $dupexpr + else + map(Const, $(primargs[i])) + end + ) + else + dupexpr = if Width == 1 + quote + iterate_unwrap_augfwd_dup(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) + end + else + quote + iterate_unwrap_augfwd_batchdup(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i])) + end + end + :( + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + if $aref == ActiveState + iterate_unwrap_augfwd_act($(primargs[i])...) 
+ elseif $aref == MixedState + T = $(primtypes[i]) + throw(AssertionError("Mixed State of type $T is unsupported in apply iterate")) + else + $dupexpr + end + else + map(Const, $(primargs[i])) end - ) - end - else - map(Const, $(primargs[i])) - end - ) + ) + end else if forwardMode quote @@ -131,16 +158,6 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, any_mixed = :($any_mixed || $aref == MixedState) end end - - if mixed_or_active - push!(active_refs, quote - active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState) for i in 1:N)...)) - end) - else - push!(active_refs, quote - active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState) for i in 1:N)...)) - end) - end push!(active_refs, quote any_mixed = $any_mixed end) @@ -230,8 +247,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) ending = if Width == 1 quote - if active_reg_nothrow(resT, Val(nothing)) == MixedState && !(initShadow isa Base.RefValue) - shadow_return = Ref(initShadow) + if annotation <: MixedDuplicated + shadow_return = initShadow tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) return ReturnType((origRet, shadow_return, tape)) else @@ -241,23 +258,11 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end end else - expr = :() - shads = Expr[] - for i in 1:Width - if i == 1 - expr = quote !(initShadow[$i] isa Base.RefValue) end - else - expr = quote $expr || !(initShadow[$i] isa Base.RefValue) end - end - push!(shads, quote - Ref(initShadow[$i]) - end) - end quote - if active_reg_nothrow(resT, Val(nothing)) == MixedState && ($expr) - shadow_return = ($(shads...),) + if annotation <: BatchMixedDuplicated + shadow_return = (initShadow...,) tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType((origRet, 
shadow_return..., tape)) + return ReturnType((origRet, initShadow..., tape)) else shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) @@ -284,6 +289,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) annotationA = if $Width != 1 && annotation0 <: Duplicated BatchDuplicated{rt, $Width} + elseif $Width != 1 && annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt, $Width} else annotation0 end @@ -315,8 +322,6 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end end - @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - $ending end end @@ -430,14 +435,14 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - if tape.shadow_return !== nothing - if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) - ET = ($(ElTypes...),) - throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) - end - end tup = if annotation0 <: Active adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + elseif annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated + if $Width == 1 + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + else + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret..., tape.internal_tape)[1] + end else adjoint(dupClosure0 ? 
Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] end @@ -493,30 +498,85 @@ end end end -@inline function iterate_unwrap_augfwd_dup(::Val{forwardMode}, args, dargs) where forwardMode +@inline function iterate_unwrap_fwd_dup(args, dargs) ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] ty = Core.Typeof(arg) if guaranteed_const(ty) Const(arg) - elseif !forwardMode && active_reg(ty) - Active(arg) else Duplicated(arg, dargs[i]) end end end -@inline function iterate_unwrap_augfwd_batchdup(::Val{forwardMode}, ::Val{Width}, args, dargs) where {forwardMode, Width} + +@inline function iterate_unwrap_fwd_batchdup(::Val{Width}, args, dargs) where {Width} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] ty = Core.Typeof(arg) if guaranteed_const(ty) Const(arg) - elseif !forwardMode && active_reg(ty) + else + BatchDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + dargs[j][i] + end) + end + end +end + +function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse, T2} + if reverse + return popfirst!(vals) + else + tmp = Base.RefValue{T2}(darg) + push!(vals, tmp) + return tmp + end +end + +function push_if_not_ref(::Val{reverse}, vals, darg::Base.RefValue{T2}, ::Type{T2}) where {reverse, T2} + return darg +end + +@inline function iterate_unwrap_augfwd_dup(::Val{reverse}, vals, args, dargs) where reverse + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + actreg = active_reg_nothrow(ty, Val(nothing)) + if actreg == AnyState + Const(arg) + elseif actreg == ActiveState Active(arg) + elseif actreg == MixedState + darg = Base.inferencebarrier(dargs[i]) + MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) + else + Duplicated(arg, dargs[i]) + end + end +end + +@inline function iterate_unwrap_augfwd_batchdup(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width} + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg 
= args[i] + ty = Core.Typeof(arg) + actreg = active_reg_nothrow(ty, Val(nothing)) + if actreg == AnyState + Const(arg) + elseif actreg == ActiveState + Active(arg) + elseif actreg == MixedState + BatchMixedDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + darg = Base.inferencebarrier(dargs[j][i]) + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} + end) else BatchDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta @@ -597,9 +657,10 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType end end -function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) +function body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) return quote + $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) @@ -609,7 +670,7 @@ end function func_runtime_iterate_fwd(N, Width) _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) - body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) quote function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} @@ -621,7 +682,7 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + return body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) end function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs @@ -631,29 +692,43 @@ function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs end end -function shadow_tuple(::Val{1}, args::Vararg{Annotation, Nargs}) where Nargs - ntuple(Val(Nargs)) do i +function shadow_tuple(::Type{Ann}, ::Val{1}, args::Vararg{Annotation, Nargs}) where {Ann, Nargs} + res = ntuple(Val(Nargs)) do i Base.@_inline_meta @assert !(args[i] isa Active) if args[i] isa Const args[i].val + elseif args[i] isa MixedDuplicated + args[i].dval[] else args[i].dval end end + if Ann <: MixedDuplicated + Ref(res) + else + res + end end -function shadow_tuple(::Val{width}, args::Vararg{Annotation, Nargs}) where {width, Nargs} +function shadow_tuple(::Type{Ann}, ::Val{width}, args::Vararg{Annotation, Nargs}) where {Ann, width, Nargs} ntuple(Val(width)) do w - ntuple(Val(Nargs)) do i - Base.@_inline_meta - @assert !(args[i] isa Active) - if args[i] isa Const - args[i].val - else - args[i].dval[w] + res = ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if args[i] isa Const + args[i].val + elseif args[i] isa BatchMixedDuplicated + args[i].dval[w][] + else + args[i].dval[w] + end + end + if Ann <: BatchMixedDuplicated + Ref(res) + else + res end - end end end @@ -669,6 +744,8 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} + elseif annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt, width} elseif annotation0 <: Active Active{rt} else @@ -677,6 +754,8 @@ function 
augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} else if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated Duplicated{rt} + elseif annotation0 <: MixedDuplicated + MixedDuplicated{rt} elseif annotation0 <: Active Active{rt} else @@ -703,19 +782,20 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) forward(fa, args...) else - nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(Val(width), args...) + nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) end resT = typeof(origRet) + if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) return ReturnType((allSame(Val(width+1), origRet)..., tape)) elseif annotation <: Active - if width == 1 - shadow_return = Ref(make_zero(origRet)) + shadow_return = if width == 1 + Ref(make_zero(origRet)) else - shadow_return = allZero(Val(width), origRet) + allZero(Val(width), origRet) end tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) if width == 1 @@ -725,30 +805,49 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} end end - @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - - shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) if width == 1 - return ReturnType((origRet, initShadow, tape)) + if annotation <: MixedDuplicated + shadow_return = initShadow + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow, tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), 
resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow, tape)) + end else - return ReturnType((origRet, initShadow..., tape)) + if annotation <: BatchMixedDuplicated + shadow_return = initShadow + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow..., tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow..., tape)) + end end end -function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) +function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + results = Expr[] + for i in 1:(Width+1) + push!(results, :(tmpvals[$i])) + end return quote + refs = Base.RefValue[] + $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType + tmpvals = augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType + ReturnType(($(results...), (tmpvals[$(Width+2)], refs))) end end function func_runtime_iterate_augfwd(N, Width) _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) - body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) + body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) quote function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} @@ -760,7 +859,7 @@ end @generated function 
runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) + return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) end @@ -781,6 +880,8 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} + elseif annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt, width} elseif annotation0 <: Active Active{rt} else @@ -789,6 +890,8 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween else if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated Duplicated{rt} + elseif annotation0 <: MixedDuplicated + MixedDuplicated{rt} elseif annotation0 <: Active Active{rt} else @@ -811,15 +914,20 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween forward, adjoint = thunk(Val(world), FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - + args2 = if tape.shadow_return !== nothing if width == 1 (args..., tape.shadow_return[]) else - (args..., ntuple(Val(width)) do w + shads = ntuple(Val(width)) do w Base.@_inline_meta tape.shadow_return[w][] - end) + end + if annotation <: MixedDuplicated || annotation <: BatchMixedDuplicated + (args..., shads...,) + else + (args..., shads) + end end else args @@ -838,6 +946,15 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween tape.shadow_return[w][][i] end end + elseif args[i] isa MixedDuplicated || args[i] 
isa BatchMixedDuplicated + if width == 1 + tape.shadow_return[][i] + else + ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][][i] + end + end else nothing end @@ -849,14 +966,20 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween ntuple(Val(width)) do w Base.@_inline_meta - - if tup[i] == nothing - else - expr = if width == 1 - tup[i] + if args[i] isa Active || args[i] isa MixedDuplicated || args[i] isa BatchMixedDuplicated + expr = if args[i] isa Active || f == Base.tuple + if width == 1 + tup[i] + else + tup[i][w] + end + elseif args[i] isa MixedDuplicated + args[i].dval[] else - tup[i][w] + # if args[i] isa BatchMixedDuplicated + args[i].dval[w][] end + idx_of_vec, idx_in_vec = lengths[i] vec = @inbounds shadowargs[idx_of_vec][w] if vec isa Base.RefValue @@ -866,7 +989,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween Base.@_inline_meta prev = getfield(vecld, i) if i == idx_in_vec - recursive_add(prev, expr) + recursive_add(prev, expr, identity, guaranteed_nonactive) else prev end @@ -876,7 +999,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween if val isa Base.RefValue val[] = recursive_add(val[], expr) elseif ismutable(vec) - @inbounds vec[idx_in_vec] = recursive_add(val, expr) + @inbounds vec[idx_in_vec] = recursive_add(val, expr, identity, guaranteed_nonactive) else error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") end @@ -891,26 +1014,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween nothing end -function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs) - outs = [] - for i in 1:N - for w in 1:Width - expr = if Width == 1 - :(tup[$i]) - else - :(tup[$i][$w]) - end - shad = shadowargs[i][w] - out = :(if tup[$i] === nothing - elseif $shad isa Base.RefValue - $shad[] = recursive_add($shad[], $expr) - else - error("Enzyme 
Mutability Error: Cannot add in place to immutable value "*string($shad)) - end - ) - push!(outs, out) - end - end +function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs, active_refs) shadow_ret = nothing if Width == 1 shadowret = :(tape.shadow_return[]) @@ -938,17 +1042,19 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado push!(shadowsplat, :(($(s...),))) end quote + (tape0, refs) = tape + $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowsplat...),), args...) + rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) return nothing end end function func_runtime_iterate_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true) - body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true; reverse=true) + body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) quote function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} @@ -959,8 +1065,8 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true; reverse=true) + return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) end # Create specializations diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 36f2798c0c..0b20dc77d4 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -603,7 +603,9 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco Base.getfield(dptr, symname) end RT = Core.Typeof(res) - if active_reg(RT) + + actreg = active_reg_nothrow(RT, Val(nothing)) + if actreg == ActiveState if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else @@ -612,6 +614,17 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco Ref{RT}(make_zero(res)) end) end + elseif actreg == MixedState + if length(dptrs) == 0 + return Ref{RT}(res) + else + fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + Ref{RT}(getfield(dv isa Base.RefValue ? 
dv[] : dv, symname)) + end)...)) + return fval + end else if length(dptrs) == 0 return res @@ -633,8 +646,8 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc Base.getfield(dptr, symname+1) end RT = Core.Typeof(res) - actreg = active_reg(RT) - if actreg + actreg = active_reg_nothrow(RT, Val(nothing)) + if actreg == ActiveState if length(dptrs) == 0 return Ref{RT}(make_zero(res))::Any else @@ -643,6 +656,17 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc Ref{RT}(make_zero(res)) end) end + elseif actreg == MixedState + if length(dptrs) == 0 + return Ref{RT}(res)::Any + else + fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname+1)) + end)...)) + return fval + end else if length(dptrs) == 0 return res::Any @@ -665,7 +689,9 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, end RT = Core.Typeof(cur) - if active_reg(RT) && !isconst + + actreg = active_reg_nothrow(RT, Val(nothing)) + if (actreg == ActiveState || actreg == MixedState) && !isconst if length(dptrs) == 0 if dptr isa Base.RefValue vload = dptr[] @@ -674,13 +700,13 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, Base.@_inline_meta prev = getfield(vload, i) if fieldname(dRT, i) == symname - recursive_add(prev, dret[]) + recursive_add(prev, dret[], identity, guaranteed_nonactive) else prev end end) else - setfield!(dptr, symname, recursive_add(cur, dret[])) + setfield!(dptr, symname, recursive_add(cur, dret[], identity, guaranteed_nonactive)) end else if dptr isa Base.RefValue @@ -690,7 +716,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, Base.@_inline_meta prev = getfield(vload, j) if fieldname(dRT, j) == symname - recursive_add(prev, dret[1][]) + recursive_add(prev, dret[1][], identity, guaranteed_nonactive) else prev end @@ -706,7 
+732,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, Base.@_inline_meta prev = getfield(vload, j) if fieldname(dRT, j) == symname - recursive_add(prev, dret[1+i][]) + recursive_add(prev, dret[1+i][], identity, guaranteed_nonactive) else prev end @@ -717,7 +743,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, else Base.getfield(dptrs[i], symname) end - setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][])) + setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) end end end @@ -733,7 +759,9 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} end RT = Core.Typeof(cur) - if active_reg(RT) && !isconst + + actreg = active_reg_nothrow(RT, Val(nothing)) + if (actreg == ActiveState || actreg == MixedState) && !isconst if length(dptrs) == 0 if dptr isa Base.RefValue vload = dptr[] @@ -742,13 +770,13 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} Base.@_inline_meta prev = getfield(vload, i) if i == symname+1 - recursive_add(prev, dret[]) + recursive_add(prev, dret[], identity, guaranteed_nonactive) else prev end end) else - setfield!(dptr, symname+1, recursive_add(cur, dret[])) + setfield!(dptr, symname+1, recursive_add(cur, dret[], identity, guaranteed_nonactive)) end else if dptr isa Base.RefValue @@ -758,13 +786,13 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} Base.@_inline_meta prev = getfield(vload, j) if j == symname+1 - recursive_add(prev, dret[1][]) + recursive_add(prev, dret[1][], identity, guaranteed_nonactive) else prev end end) else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + setfield!(dptr, symname+1, recursive_add(cur, dret[1][], identity, guaranteed_nonactive)) end for i in 1:length(dptrs) if dptrs[i] isa Base.RefValue @@ -774,7 +802,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, 
::Val{isconst} Base.@_inline_meta prev = getfield(vload, j) if j == symname+1 - recursive_add(prev, dret[1+i][]) + recursive_add(prev, dret[1+i][], identity, guaranteed_nonactive) else prev end @@ -785,7 +813,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} else Base.getfield(dptrs[i], symname+1) end - setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][])) + setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) end end end diff --git a/test/applyiter.jl b/test/applyiter.jl index b1a26e5f54..11e9ebf37c 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -503,3 +503,5 @@ end Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) @test ddata[1][1] ≈ 6.0 end + +include("mixedapplyiter.jl") \ No newline at end of file diff --git a/test/mixedapplyiter.jl b/test/mixedapplyiter.jl new file mode 100644 index 0000000000..bb7f18243c --- /dev/null +++ b/test/mixedapplyiter.jl @@ -0,0 +1,144 @@ +using Enzyme, Test + +concat() = () +concat(a) = a +concat(a, b) = (a..., b...) +concat(a, b, c...) = concat(concat(a, b), c...) + +metaconcat(x) = concat(x...) + +metaconcat2(x, y) = concat(x..., y...) + +midconcat(x, y) = (x, concat(y...)...) + +metaconcat3(x, y, z) = concat(x..., y..., z...) + +function mixed_metasumsq(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v::Tuple{Float64, Vector{Float64}} + res += v[1]*v[1] + v[2][1] * v[2][1] + end + return res +end + +function mixed_metasumsq3(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v + res += v*v + end + return res +end + +function make_byref(out, fn, args...) + out[] = fn(args...) 
+ nothing +end + +function tupapprox(a, b) + if a isa Tuple && b isa Tuple + if length(a) != length(b) + return false + end + for (aa, bb) in zip(a, b) + if !tupapprox(aa, bb) + return false + end + end + return true + end + if a isa Array && b isa Array + if size(a) != size(b) + return false + end + for i in eachindex(a) + if !tupapprox(a[i], b[i]) + return false + end + end + return true + end + return a ≈ b +end + +@testset "Mixed Reverse Apply iterate (tuple)" begin + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) + + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] + + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ 5562.9996 + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) +end + +@testset "BatchMixed Reverse Apply iterate (tuple)" begin + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + dx2 = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) + @test tupapprox(dx2, [((3*4.0, [3*5.4]), (3*6.0, [3*6.28])), ((3*15.8, [3*94.0]), (3*22.4, [3*112.0]))]) + + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + dx2
= [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ 5562.9996 + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) + @test tupapprox(dx2, [((3*4.0, [3*5.4]), (3*6.0, [3*6.28])), ((3*15.8, [3*94.0]), (3*22.4, [3*112.0]))]) +end + + +@testset "Mixed Reverse Apply iterate (list)" begin + x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) + + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ 5562.9996 + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) +end + +@testset "BatchMixed Reverse Apply iterate (list)" begin + x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + dx2 = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) + @test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]]) + + x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, 
[0.0]), (0.0, [0.0])]] + dx2 = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ 5562.9996 + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) + @test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]]) +end \ No newline at end of file diff --git a/test/usermixed.jl b/test/usermixed.jl index b5cd0e158b..f97c5737ec 100644 --- a/test/usermixed.jl +++ b/test/usermixed.jl @@ -1,6 +1,122 @@ using Enzyme using Test +########## MixedDuplicated of Return + +function user_mixret(x, y) + return (x, y) +end + +@testset "MixedDuplicated struct return" begin + x = 2.7 + y = [3.14] + dy = [0.0] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(user_mixret)}, MixedDuplicated, Active{Float64}, Duplicated{Vector{Float64}}) + + tape, res, dres = fwd(Const(user_mixret), Active(x), Duplicated(y, dy)) + + @test res[1] ≈ x + @test res[2] === y + + @test dres[][1] ≈ 0.0 + @test dres[][2] === dy + + outs = rev(Const(user_mixret), Active(x), Duplicated(y, dy), (47.56, dy), tape) + + @test outs[1][1] ≈ 47.56 +end + +@testset "BatchMixedDuplicated struct return" begin + x = 2.7 + y = [3.14] + dy = [0.0] + dy2 = [0.0] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(user_mixret)}, BatchMixedDuplicated, Active{Float64}, BatchDuplicated{Vector{Float64}, 2}) + + tape, res, dres = fwd(Const(user_mixret), Active(x), BatchDuplicated(y, (dy, dy2))) + + @test res[1] ≈ x + @test res[2] === y + + @test dres[1][][1] ≈ 0.0 + @test dres[1][][2] === dy + @test dres[2][][1] ≈ 0.0 + @test dres[2][][2] === dy2 + + outs = rev(Const(user_mixret), Active(x), BatchDuplicated(y, (dy, dy2)), (47.0, dy), (56.0, dy), tape) + + @test outs[1][1][1] ≈ 47.0 + 
@test outs[1][1][2] ≈ 56.0 +end + + +function user_fltret(x, y) + return x +end + +@testset "MixedDuplicated float return" begin + x = 2.7 + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(identity)}, MixedDuplicated, Active{Float64}) + + tape, res, dres = fwd(Const(identity), Active(x)) + + @test res ≈ x + @test dres[] ≈ 0.0 + + outs = rev(Const(identity), Active(x), 47.56, tape) + + @test outs[1][1] ≈ 47.56 +end + +@testset "BatchMixedDuplicated float return" begin + x = 2.7 + y = [3.14] + dy = [0.0] + dy2 = [0.0] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(user_fltret)}, BatchMixedDuplicated, Active{Float64}, BatchDuplicated{Vector{Float64}, 2}) + + tape, res, dres = fwd(Const(user_fltret), Active(x), BatchDuplicated(y, (dy, dy2))) + + @test res ≈ x + + @test dres[1][] ≈ 0.0 + @test dres[2][] ≈ 0.0 + + outs = rev(Const(user_fltret), Active(x), BatchDuplicated(y, (dy, dy2)), 47.0, 56.0, tape) + + @test outs[1][1][1] ≈ 47.0 + @test outs[1][1][2] ≈ 56.0 +end + +function vecsq(x) + x[2] = x[1] * x[1] + return x +end + +@testset "MixedDuplicated vector return" begin + y = [3.14, 0.0] + dy = [0.0, 2.7] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(vecsq)}, MixedDuplicated, Duplicated{Vector{Float64}}) + + tape, res, dres = fwd(Const(vecsq), Duplicated(y, dy)) + + @test res === y + + @test dres[] === dy + + outs = rev(Const(vecsq), Duplicated(y, dy), dy, tape) + + @test dy ≈ [3.14 * 2.7 * 2, 0.0] +end + + +########## MixedDuplicated of Argument + function user_mixfnc(tup) return tup[1] * tup[2][1] end From 82cc451b0ebcf4434fb8fc73f51c58ba1ea6be96 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 15 Jun 2024 09:32:24 -0400 Subject: [PATCH 120/495] Abstract is mixed (#1536) * Abstract is mixed * fix unionall * fix * more fixups --- src/compiler.jl | 97 ++++++++++++++++++++++------------ src/rules/jitrules.jl | 57 +++++++++++++++++++- src/rules/typeunstablerules.jl | 24 ++++----- test/mixedapplyiter.jl | 
25 ++++++++- test/runtests.jl | 3 ++ 5 files changed, 158 insertions(+), 48 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7bbb1bbedd..cd5403b6c2 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -252,16 +252,30 @@ end ActivityState(Int(a1) | Int(a2)) end -struct Merger{seen,worldT,justActive,UnionSret} +struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed} world::worldT end @inline element(::Val{T}) where T = T -@inline function (c::Merger{seen,worldT,justActive,UnionSret})(f::Int) where {seen,worldT,justActive,UnionSret} +# From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 +@inline function isghostty(ty) + if ty === Union{} + return true + end + if Base.isconcretetype(ty) && !ismutabletype(ty) + if sizeof(ty) == 0 + return true + end + # TODO consider struct_to_llvm ? + end + return false +end + +@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})(f::Int) where {seen,worldT,justActive,UnionSret,AbstractIsMixed} T = element(first(seen)) - reftype = ismutabletype(T) || T isa UnionAll + reftype = ismutabletype(T) || (T isa UnionAll && !AbstractIsMixed) if justActive && reftype return Val(AnyState) @@ -273,7 +287,7 @@ end return Val(AnyState) end - sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret)) + sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) if sub == AnyState Val(AnyState) @@ -372,16 +386,23 @@ end end) end -@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}) where {ST, Seen, justActive, UnionSret} +@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}, ::Val{AbstractIsMixed}) where {ST, Seen, justActive, UnionSret, AbstractIsMixed} if ST isa Union - return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret))), Val(active_reg_recur(ST.b, seen, 
world, Val(justActive), Val(UnionSret)))) + return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)))) end - return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret)) + return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) end -@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} +@inline is_vararg_tup(x) = false +@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where T2 = true + +@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false), ::Val{AbstractIsMixed}=Val(false))::ActivityState where {ST,T, justActive, UnionSret, AbstractIsMixed} if T === Any - return DupState + if AbstractIsMixed + return MixedState + else + return DupState + end end if T === Union{} @@ -389,7 +410,7 @@ end end if T <: Complex && !(T isa UnionAll) - return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret)) + return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) end if T <: AbstractFloat @@ -401,10 +422,14 @@ end return AnyState end - if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret)) == AnyState + if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) == AnyState return AnyState else - return DupState + if AbstractIsMixed && is_vararg_tup(T) + return MixedState + else + return DupState + end end end @@ -434,10 +459,18 @@ end if T isa UnionAll aT = Base.argument_datatype(T) if aT === nothing - return DupState + if AbstractIsMixed + return MixedState + else + 
return DupState + end end if datatype_fieldcount(aT) === nothing - return DupState + if AbstractIsMixed + return MixedState + else + return DupState + end end end @@ -445,16 +478,24 @@ end # if sret union, the data is stored in a stack memory location and is therefore # not unique'd preventing the boxing of the union in the default case if UnionSret && is_sret_union(T) - return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret)) + return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) else if justActive return AnyState end if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != AnyState - return DupState + if AbstractIsMixed + return MixedState + else + return DupState + end end if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != AnyState - return DupState + if AbstractIsMixed + return MixedState + else + return DupState + end end end return AnyState @@ -462,7 +503,11 @@ end # if abstract it must be by reference if Base.isabstracttype(T) - return DupState + if AbstractIsMixed + return MixedState + else + return DupState + end end if ismutabletype(T) @@ -504,7 +549,7 @@ end seen2 = (Val(nT), seen...) - fty = Merger{seen2,typeof(world),justActive, UnionSret}(world) + fty = Merger{seen2,typeof(world),justActive, UnionSret, AbstractIsMixed}(world) ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...) @@ -1158,20 +1203,6 @@ function permit_inlining!(f::LLVM.Function) end end -# From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 -@inline function isghostty(ty) - if ty === Union{} - return true - end - if Base.isconcretetype(ty) && !ismutabletype(ty) - if sizeof(ty) == 0 - return true - end - # TODO consider struct_to_llvm ? 
- end - return false -end - struct Tape{TapeTy,ShadowTy,ResT} internal_tape::TapeTy shadow_return::ShadowTy diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f04145f7bc..dc0bd04d80 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -95,6 +95,15 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end ) else + mixexpr = if Width == 1 + quote + iterate_unwrap_augfwd_mix(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) + end + else + quote + iterate_unwrap_augfwd_batchmix(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i])) + end + end dupexpr = if Width == 1 quote iterate_unwrap_augfwd_dup(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) @@ -110,8 +119,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, if $aref == ActiveState iterate_unwrap_augfwd_act($(primargs[i])...) elseif $aref == MixedState - T = $(primtypes[i]) - throw(AssertionError("Mixed State of type $T is unsupported in apply iterate")) + $mixexpr else $dupexpr end @@ -586,6 +594,51 @@ end end end +@inline function iterate_unwrap_augfwd_mix(::Val{reverse}, vals, args, dargs0) where reverse + dargs = dargs0[] + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + actreg = active_reg_nothrow(ty, Val(nothing)) + if actreg == AnyState + Const(arg) + elseif actreg == ActiveState + Active(arg) + elseif actreg == MixedState + darg = Base.inferencebarrier(dargs[i]) + MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) + else + Duplicated(arg, dargs[i]) + end + end +end + +@inline function iterate_unwrap_augfwd_batchmix(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width} + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + actreg = active_reg_nothrow(ty, Val(nothing)) + if actreg == AnyState + Const(arg) + elseif actreg == ActiveState + Active(arg) + elseif 
actreg == MixedState + BatchMixedDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + darg = Base.inferencebarrier(dargs[j][][i]) + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} + end) + else + BatchDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + dargs[j][][i] + end) + end + end +end + @inline function allFirst(::Val{Width}, res) where Width ntuple(Val(Width)) do i Base.@_inline_meta diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 0b20dc77d4..3ed5cfd72f 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -11,13 +11,13 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch shadow_rets_i = Expr[] aref = Symbol("active_ref_$i") for w in 1:Width - sref = Symbol("shadow_"*string(i)*"_"*string(w)) + sref = Symbol("sub_shadow_"*string(i)*"_"*string(w)) push!(shadow_rets_i, quote $sref = if $aref == AnyState $(primargs[i]); else if !ActivityTup[$i] - if $aref == DupState || $aref == MixedState + if ($aref == DupState || $aref == MixedState) && $(batchshadowargs[i][w]) === nothing prim = $(primargs[i]) throw("Error cannot store inactive but differentiable variable $prim into active tuple") end @@ -98,7 +98,7 @@ function body_construct_rev(N, Width, primtypes, active_refs, primargs, batchsha shad = batchshadowargs[i][w] out = :(if $(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState if $shad isa Base.RefValue - $shad[] = recursive_add($shad[], $expr) + $shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive) else error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) end @@ -248,10 +248,10 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) # if any active [e.g. 
ActiveState / MixedState] data could exist # err if !fwd - if !found + if !found_partial return false end - act = active_reg_inner(typ, (), world) + act = active_reg_inner(typ_partial, (), world, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) if act == MixedState || act == ActiveState return false end @@ -306,7 +306,7 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) +function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) @@ -379,7 +379,7 @@ function common_f_tuple_fwd(offset, B, orig, gutils, normalR, shadowR) common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) end -function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) +function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) @@ -420,8 +420,8 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) unsafe_store!(tapeR, sret.ref) - return false end + return false end function common_f_tuple_rev(offset, B, orig, gutils, tape) @@ -474,7 +474,7 @@ function f_tuple_fwd(B, orig, gutils, normalR, shadowR) common_f_tuple_fwd(1, B, orig, gutils, normalR, shadowR) end -function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool common_f_tuple_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end @@ -487,7 +487,7 @@ function new_structv_fwd(B, orig, gutils, normalR, shadowR) common_newstructv_fwd(1, B, orig, gutils, 
normalR, shadowR) end -function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool common_newstructv_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end @@ -525,7 +525,7 @@ function new_structt_fwd(B, orig, gutils, normalR, shadowR) unsafe_store!(shadowR, shadowres.ref) return false end -function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool new_structt_fwd(B, orig, gutils, normalR, shadowR) end @@ -821,7 +821,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} return nothing end -function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) +function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end diff --git a/test/mixedapplyiter.jl b/test/mixedapplyiter.jl index bb7f18243c..0a4f06cbb9 100644 --- a/test/mixedapplyiter.jl +++ b/test/mixedapplyiter.jl @@ -141,4 +141,27 @@ end @test out[] ≈ 5562.9996 @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) @test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]]) -end \ No newline at end of file +end + +struct MyRectilinearGrid5{FT,FZ} + x :: FT + z :: FZ +end + + +@inline flatten_tuple(a::Tuple) = @inbounds a[2:end] +@inline flatten_tuple(a::Tuple{<:Any}) = tuple() #inner_flatten_tuple(a[1])...) 
+ +function myupdate_state!(model) + tupled = Base.inferencebarrier((model,model)) + flatten_tuple(tupled) + return nothing +end + +@testset "Abstract type allocation" begin + model = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0]) + dmodel = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0]) + autodiff(Enzyme.Reverse, + myupdate_state!, + MixedDuplicated(model, Ref(dmodel))) +end diff --git a/test/runtests.jl b/test/runtests.jl index ca05883c13..719687ad42 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -140,6 +140,9 @@ end @assert Enzyme.Compiler.active_reg_inner(Tuple{S,Int64} where S, (), Base.get_world_counter()) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing, #=justActive=#Val(false), #=unionSret=#Val(true)) == Enzyme.Compiler.ActiveState + @test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing) == Enzyme.Compiler.DupState + @test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState + @test Enzyme.Compiler.active_reg_inner(Tuple{A,A} where A, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState world = codegen_world_age(typeof(f0), Tuple{Float64}) thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI) thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI) From 53f64a6397ea75f6a8c2fd12899a177f30a9a1d4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 15 Jun 2024 23:08:10 -0400 Subject: [PATCH 121/495] inactive kwargs 
(#1539) * inactive kwargs * only kwargs * fixup * fix * fix --- src/compiler.jl | 43 +++++++++++++++++++++++++++++++++----- src/compiler/validation.jl | 1 + src/internal_rules.jl | 2 +- 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index cd5403b6c2..f2d1829571 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1087,6 +1087,7 @@ end function get_array_struct() +@static if VERSION < v"1.11-" # JL_EXTENSION typedef struct { # JL_DATA_TYPE # void *data; @@ -1117,6 +1118,41 @@ function get_array_struct() nrows = LLVM.IntType(8*sizeof(Csize_t)) return LLVM.StructType([ptrty, sizeT, arrayFlags, elsz, off, nrows]; packed=true) +else +# JL_EXTENSION typedef struct { +# JL_DATA_TYPE +# size_t length; +# void *ptr; +# // followed by padding and inline data, or owner pointer +# #ifdef _P64 +# // union { +# // jl_value_t *owner; +# // T inl[]; +# // }; +# #else +# // +# // jl_value_t *owner; +# // size_t padding[1]; +# // T inl[]; +# #endif +# } jl_genericmemory_t; +# +# JL_EXTENSION typedef struct { +# JL_DATA_TYPE +# void *ptr_or_offset; +# jl_genericmemory_t *mem; +# } jl_genericmemoryref_t; +# +# JL_EXTENSION typedef struct { +# JL_DATA_TYPE +# jl_genericmemoryref_t ref; +# size_t dimsize[]; // length for 1-D, otherwise length is mem->length +# } jl_array_t; + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 10) + sizeT = LLVM.IntType(8*sizeof(Csize_t)) + return LLVM.StructType([ptrty, sizeT]; packed=true) +end end function get_array_data(B, array) @@ -1171,9 +1207,6 @@ function get_array_nrows(B, array) return LLVM.load!(B, nrows, v) end -dedupargs() = () -dedupargs(a, da, args...) = (a, dedupargs(args...)...) 
- # Force sret struct Return2 ret1::Any @@ -5398,11 +5431,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) continue end - if EnzymeRules.is_inactive_from_sig(mi.specTypes; world, method_table, caller) + if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive), specTypes.parameters...}, world, method_table) handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")]) continue end - if EnzymeRules.is_inactive_noinl_from_sig(mi.specTypes; world, method_table, caller) + if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive_noinl), specTypes.parameters...}, world, method_table) handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")], false, false) for bb in blocks(llvmfn) for inst in instructions(bb) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index caf86cbc03..8715ce1991 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -345,6 +345,7 @@ end end @inline function is_inactive(tys, world::UInt, mt) + specTypes = Interpreter.simplify_kw(Tuple{tys...}) if has_method(Tuple{typeof(EnzymeRules.inactive), tys...}, world, mt) return true end diff --git a/src/internal_rules.jl b/src/internal_rules.jl index ea33959b23..fe5ed05b89 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -15,7 +15,7 @@ end function EnzymeRules.inactive(::typeof(Base.fixup_stdlib_path), args...) 
return nothing end -function EnzymeRules.inactive(::typeof(Base.CoreLogging.handle_message), args...) +function EnzymeRules.inactive(::typeof(Base.CoreLogging.handle_message), args...; kwargs...) return nothing end function EnzymeRules.inactive(::typeof(Base.CoreLogging.logging_error), args...) From e9804331d9e5ef6c277982e68f9077203b6d1a37 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 16 Jun 2024 12:06:35 -0400 Subject: [PATCH 122/495] Quick adaptation to new pm (#1538) * Quick adaptation to new pm * fixup * fix passes * effects * fix versioning * more fixes * fixup * add gcloaded fb * Update utils.jl * more 1.11 * Update utils.jl * Update utils.jl * Update compiler.jl --- src/absint.jl | 9 +- src/compiler.jl | 220 +++++++++++++----- src/compiler/optimize.jl | 443 ++++++++++++++++++++++++++++++------- src/compiler/utils.jl | 182 +++++++++++++++ src/compiler/validation.jl | 89 ++++++-- src/rules/llvmrules.jl | 54 +++++ src/typetree.jl | 50 +++-- 7 files changed, 887 insertions(+), 160 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index ae9c35a09b..10dc024013 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -323,12 +323,17 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} - if isa(arg, ConstantExpr) ce = arg while isa(ce, ConstantExpr) if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr ce = operands(ce)[1] + elseif opcode(ce) == LLVM.API.LLVMGetElementPtr + if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) + ce = operands(ce)[1] + else + break + end else break end @@ -336,7 +341,7 @@ function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} if isa(ce, LLVM.GlobalVariable) ce = LLVM.initializer(ce) if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) - return (true, 
String(map((x)->convert(UInt8, x), collect(flib)[1:(end-1)]))) + return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) end end diff --git a/src/compiler.jl b/src/compiler.jl index f2d1829571..3f57743cc0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3238,7 +3238,11 @@ function annotate!(mod, mode) for fname in ("julia.typeof",) if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readnone", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end @@ -3246,15 +3250,18 @@ function annotate!(mod, mode) for fname in ("jl_excstack_state","ijl_excstack_state") if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end for fname in ("jl_types_equal", "ijl_types_equal") if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end @@ -3278,7 +3285,12 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("readonly", 0)) + attr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("readonly") + else 
+ EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + end + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) end end end @@ -3287,7 +3299,11 @@ function annotate!(mod, mode) if haskey(fns, fname) fn = fns[fname] # TODO per discussion w keno perhaps this should change to readonly / inaccessiblememonly - push!(function_attributes(fn), LLVM.EnumAttribute("readnone", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end @@ -3320,7 +3336,11 @@ function annotate!(mod, mode) for fname in ("julia.pointer_from_objref",) if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readnone", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end end end @@ -3336,8 +3356,13 @@ function annotate!(mod, mode) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) push!(function_attributes(fn), no_escaping_alloc) + accattr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("inaccessiblememonly") + else + EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + end if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + push!(function_attributes(fn), accattr) end for u in LLVM.uses(fn) c = LLVM.user(u) @@ -3348,7 +3373,7 @@ function annotate!(mod, 
mode) if cf == fn LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), accattr) end end if !isa(cf, LLVM.Function) @@ -3363,7 +3388,12 @@ function annotate!(mod, mode) LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) + attr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("inaccessiblememonly") + else + EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + end + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) end end end @@ -3372,14 +3402,22 @@ function annotate!(mod, mode) for gc in ("llvm.julia.gc_preserve_begin", "llvm.julia.gc_preserve_end") if haskey(fns, gc) fn = fns[gc] - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << 
getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end for rfn in ("jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") if haskey(fns, rfn) fn = fns[rfn] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end end end @@ -3388,8 +3426,12 @@ function annotate!(mod, mode) if haskey(fns, rfn) fn = fns[rfn] push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end # Key of jl_eqtable_get/put is inactive, definitionally @@ -3400,15 +3442,23 @@ function annotate!(mod, mode) push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end for rfn in ("jl_in_threaded_region_", "jl_in_threaded_region") if haskey(fns, rfn) fn = fns[rfn] - 
push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end end @@ -4893,17 +4943,26 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if kind(prev) == kind(StringAttribute("enzyme_shouldrecompute")) push!(attributes, prev) end - if kind(prev) == kind(EnumAttribute("readonly")) - push!(attributes, prev) - end - if kind(prev) == kind(EnumAttribute("readnone")) - push!(attributes, prev) + if LLVM.version().major <= 15 + if kind(prev) == kind(EnumAttribute("readonly")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("readnone")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("argmemonly")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("inaccessiblememonly")) + push!(attributes, prev) + end end - if kind(prev) == kind(EnumAttribute("argmemonly")) - push!(attributes, prev) + if LLVM.version().major > 15 + if kind(prev) == kind(EnumAttribute("memory")) + old = MemoryEffect(value(attr)) + mem = MemoryEffect(( set_writing(getModRef(old, ArgMem)) << getLocationPos(ArgMem)) | (getModRef(old, InaccessibleMem) << getLocationPos(InaccessibleMem)) | (getModRef(old, Other) << getLocationPos(Other))) + push!(attributes, EnumAttribute("memory", mem.data)) end - if kind(prev) == kind(EnumAttribute("inaccessiblememonly")) - push!(attributes, prev) end if kind(prev) == kind(EnumAttribute("speculatable")) push!(attributes, prev) @@ -5382,44 +5441,85 @@ function GPUCompiler.codegen(output::Symbol, 
job::CompilerJob{<:EnzymeTarget}; sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) - handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), - EnumAttribute("readnone", 0), - EnumAttribute("speculatable", 0), + if LLVM.version().major <= 15 + handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), + EnumAttribute("readnone"), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute") + ]) + else + handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), + EnumAttribute("memory", NoEffects.data), + EnumAttribute("speculatable"), StringAttribute("enzyme_shouldrecompute") ]) + end continue end if func == typeof(Base.to_tuple_type) - handleCustom(llvmfn, "jl_to_tuple_type", - [EnumAttribute("readonly", 0), - EnumAttribute("inaccessiblememonly", 0), - EnumAttribute("speculatable", 0), - StringAttribute("enzyme_shouldrecompute"), - StringAttribute("enzyme_inactive"), - ]) + if LLVM.version().major <= 15 + handleCustom(llvmfn, "jl_to_tuple_type", + [EnumAttribute("readonly"), + EnumAttribute("inaccessiblememonly", 0), + EnumAttribute("speculatable", 0), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + ]) + else + handleCustom(llvmfn, "jl_to_tuple_type", + [ + EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + EnumAttribute("inaccessiblememonly", 0), + EnumAttribute("speculatable", 0), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + ]) + end continue end if func == typeof(Base.mightalias) - handleCustom(llvmfn, "jl_mightalias", - [EnumAttribute("readonly", 0), - StringAttribute("enzyme_shouldrecompute"), - StringAttribute("enzyme_inactive"), - 
StringAttribute("enzyme_no_escaping_allocation"), - EnumAttribute("nofree"), - StringAttribute("enzyme_ta_norecur"), - ], true, false) + if LLVM.version().major <= 15 + handleCustom(llvmfn, "jl_mightalias", + [EnumAttribute("readonly"), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation"), + EnumAttribute("nofree"), + StringAttribute("enzyme_ta_norecur"), + ], true, false) + else + handleCustom(llvmfn, "jl_mightalias", + [ + EnumAttribute("memory", ReadOnlyEffects.data), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation"), + EnumAttribute("nofree"), + StringAttribute("enzyme_ta_norecur"), + ], true, false) + end continue end if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) name = (func == typeof(Base.Threads.threadid)) ? "jl_threadid" : "jl_nthreads" - handleCustom(llvmfn, name, - [EnumAttribute("readonly", 0), - EnumAttribute("inaccessiblememonly", 0), - EnumAttribute("speculatable", 0), - StringAttribute("enzyme_shouldrecompute"), - StringAttribute("enzyme_inactive"), - StringAttribute("enzyme_no_escaping_allocation") - ]) + if LLVM.version().major <= 15 + handleCustom(llvmfn, name, + [EnumAttribute("readonly"), + EnumAttribute("inaccessiblememonly"), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation") + ]) + else + handleCustom(llvmfn, name, + [EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation") + ]) + end continue end # Since this is noreturn and it can't write to any 
operations in the function @@ -5428,7 +5528,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # fn, but it doesn't presently so for now we will ensure this by hand if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) + if LLVM.version().major <= 15 + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) + else + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), + EnumAttribute("memory", ReadOnlyEffects.data), + StringAttribute("enzyme_ta_norecur")]) + end continue end if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive), specTypes.parameters...}, world, method_table) @@ -5576,9 +5682,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; res = call!(builder, LLVM.function_type(llvmfn), llvmfn, collect(parameters(wrapper_f))) + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for idx in length(collect(parameters(llvmfn))) for attr in collect(parameter_attributes(llvmfn, idx)) - if kind(attr) == kind(EnumAttribute("sret")) + if kind(attr) == sretkind LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(idx), attr) end end @@ -5708,6 +5819,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_inactive")) end + TapeType::Type = Cvoid if params.run_enzyme diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 842ebe0fe3..dbda1b0880 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1,4 +1,4 @@ -mutable 
struct PipelineConfig +struct PipelineConfig Speedup::Cint Size::Cint lower_intrinsics::Cint @@ -17,7 +17,8 @@ end const RunAttributor = Ref(true) -function pipeline_options(; lower_intrinsics=true, dump_native=false, external_use=false, llvm_only=false, always_inline=true, enalbe_early_simplifications=true, +function pipeline_options(; lower_intrinsics=true, dump_native=false, external_use=false, llvm_only=false, always_inline=true, enable_early_simplifications=true, + enable_early_optimizations=true, enable_scalar_optimizations=true, enable_loop_optimizations=true, enable_vector_pipeline=true, @@ -26,6 +27,255 @@ function pipeline_options(; lower_intrinsics=true, dump_native=false, external_u return PipelineConfig(Speedup, Size, lower_intrinsics, dump_native, external_use, llvm_only, always_inline, enable_early_simplifications, enable_early_optimizations, enable_scalar_optimizations, enable_loop_optimizations, enable_vector_pipeline, remove_ni, cleanup) end +function run_jl_pipeline(pm, tm; kwargs...) 
+ config = Ref(pipeline_options(;kwargs...)) + function jl_pipeline(m) + @dispose pb=PassBuilder(tm) begin + NewPMModulePassManager(pb) do mpm + @ccall jl_build_newpm_pipeline(mpm.ref::Ptr{Cvoid}, pb.ref::Ptr{Cvoid}, config::Ptr{PipelineConfig})::Cvoid + run!(mpm, m, tm) + end + end + return true + end + add!(pm, ModulePass("JLPipeline", jl_pipeline)) +end + +@static if VERSION < v"1.11.0-DEV.428" +else + barrier_noop!(pm) = nothing +end + +@static if VERSION < v"1.11-" + function gc_invariant_verifier_tm!(pm, tm, cond) + gc_invariant_verifier!(pm, cond) + end +else + function gc_invariant_verifier_tm!(pm, tm, cond) + function gc_invariant_verifier(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, GCInvariantVerifierPass(GCInvariantVerifierPassOptions(;strong=cond))) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("GCInvariantVerifier", gc_invariant_verifier)) + end +end + +@static if VERSION < v"1.11-" + function propagate_julia_addrsp_tm!(pm, tm) + propagate_julia_addrsp!(pm) + end +else + function propagate_julia_addrsp_tm!(pm, tm) + function prop_julia_addr(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, PropagateJuliaAddrspacesPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("PropagateJuliaAddrSpace", prop_julia_addr)) + end +end + +@static if VERSION < v"1.11-" + function alloc_opt_tm!(pm, tm) + alloc_opt!(pm) + end +else + function alloc_opt_tm!(pm, tm) + function alloc_opt(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, AllocOptPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("AllocOpt", alloc_opt)) + end +end + +@static if VERSION < v"1.11-" + function remove_ni_tm!(pm, tm) + remove_ni!(pm) + end +else + function remove_ni_tm!(pm, tm) + function remove_ni(f) + @dispose pb=PassBuilder(tm) begin + NewPMModulePassManager(pb) do fpm + add!(fpm, 
RemoveNIPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, ModulePass("RemoveNI", remove_ni)) + end +end + +@static if VERSION < v"1.11-" + function julia_licm_tm!(pm, tm) + julia_licm!(pm) + end +else + function julia_licm_tm!(pm, tm) + function julia_licm(f) + @dispose pb=PassBuilder(tm) begin + NewPMLoopPassManager(pb) do fpm + add!(fpm, JuliaLICMPass()) + run!(fpm, f, tm) + end + end + return true + end + # really looppass + add!(pm, FunctionPass("JuliaLICM", julia_licm)) + end +end + +@static if VERSION < v"1.11-" + function lower_simdloop_tm!(pm, tm) + lower_simdloop!(pm) + end +else + function lower_simdloop_tm!(pm, tm) + function lower_simdloop(f) + @dispose pb=PassBuilder(tm) begin + NewPMLoopPassManager(pb) do fpm + add!(fpm, LowerSIMDLoopPass()) + run!(fpm, f, tm) + end + end + return true + end + # really looppass + add!(pm, FunctionPass("LowerSIMDLoop", lower_simdloop)) + end +end + +@static if VERSION < v"1.11-" + function demote_float16_tm!(pm, tm) + demote_float16!(pm) + end +else + function demote_float16_tm!(pm, tm) + function demote_float16(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, DemoteFloat16Pass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("DemoteFloat16", demote_float16)) + end +end + +@static if VERSION < v"1.11-" + function lower_exc_handlers_tm!(pm, tm) + lower_exc_handlers!(pm) + end +else + function lower_exc_handlers_tm!(pm, tm) + function lower_exc_handlers(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, LowerExcHandlersPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("LowerExcHandlers", lower_exc_handlers)) + end +end + +@static if VERSION < v"1.11-" + function lower_ptls_tm!(pm, tm, dump_native) + lower_ptls!(pm, dump_native) + end +else + function lower_ptls_tm!(pm, tm, dump_native) + function lower_ptls(f) + @dispose pb=PassBuilder(tm) begin + 
NewPMModulePassManager(pb) do fpm + add!(fpm, LowerPTLSPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, ModulePass("LowerPTLS", lower_ptls)) + end +end + +@static if VERSION < v"1.11-" + function combine_mul_add_tm!(pm, tm) + combine_mul_add!(pm) + end +else + function combine_mul_add_tm!(pm, tm) + function combine_mul_add(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, CombineMulAddPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("CombineMulAdd", combine_mul_add)) + end +end + +@static if VERSION < v"1.11-" + function late_lower_gc_frame_tm!(pm, tm) + late_lower_gc_frame!(pm) + end +else + function late_lower_gc_frame_tm!(pm, tm) + function late_lower_gc_frame(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, LateLowerGCPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("LateLowerGCFrame", late_lower_gc_frame)) + end +end + +@static if VERSION < v"1.11-" + function final_lower_gc_tm!(pm, tm) + final_lower_gc!(pm) + end +else + function final_lower_gc_tm!(pm, tm) + function final_lower_gc(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, FinalLowerGCPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("FinalLowerGCFrame", final_lower_gc)) + end +end + function addNA(inst, node::LLVM.Metadata, MD) md = metadata(inst) next = nothing @@ -626,6 +876,11 @@ function fix_decayaddr!(mod::LLVM.Module) mayread = false maywrite = false sret = true + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for (i, v) in enumerate(operands(st)[1:end-1]) if v == inst readnone = false @@ -633,7 +888,7 @@ function fix_decayaddr!(mod::LLVM.Module) writeonly = false t_sret = false for a in collect(parameter_attributes(fop, i)) - if kind(a) == kind(EnumAttribute("sret")) + if kind(a) 
== sretkind t_sret = true end if kind(a) == kind(StringAttribute("enzyme_sret")) @@ -803,7 +1058,7 @@ function prop_global!(g) end # From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 -function mayWriteToMemory(inst::LLVM.Instruction)::Bool +function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly=false)::Bool # we will ignore fense here if isa(inst, LLVM.StoreInst) return true @@ -838,9 +1093,14 @@ function mayWriteToMemory(inst::LLVM.Instruction)::Bool return false end # Note out of spec, and only legal in context of removing unused calls - if kind(attr) == kind(StringAttribute("enzyme_error")) + if kind(attr) == kind(StringAttribute("enzyme_error")) && err_is_readonly return false end + if LLVM.version().major > 15 && kind(attr) == kind(EnumAttribute("memory")) # memory is an enum attribute; guarded like is_readonly(f) + if is_readonly(MemoryEffect(value(attr))) + return false + end + end end Libc.free(Attrs) return true @@ -887,8 +1147,7 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end push!(done, cur) - attrs = collect(function_attributes(cur)) - if any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) || any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + if is_readonly(cur) continue end @@ -901,7 +1160,7 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end for bb in blocks(cur) for inst in instructions(bb) - if !mayWriteToMemory(inst) + if !mayWriteToMemory(inst; err_is_readonly=true) continue end if isa(inst, LLVM.CallInst) @@ -917,17 +1176,7 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end end - changed = false - attrs = collect(function_attributes(fn)) - if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) - if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs) - delete!(function_attributes(fn), EnumAttribute("writeonly")) - push!(function_attributes(fn), 
EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("readonly")) - end - changed = true - end + changed = set_readonly!(fn) if length(calls) == 0 || hasUser return changed @@ -1345,6 +1594,11 @@ function validate_return_roots!(mod) enzyme_srets_v = Int[] rroots = Int[] rroots_v = Int[] + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for (i, a) in enumerate(parameters(f)) for attr in collect(parameter_attributes(f, i)) if isa(attr, StringAttribute) @@ -1361,7 +1615,7 @@ function validate_return_roots!(mod) push!(enzyme_srets, i) end end - if kind(attr) == kind(EnumAttribute("sret")) + if kind(attr) == sretkind push!(srets, (i, attr)) end end @@ -1519,7 +1773,7 @@ end cse!(pm) = LLVM.API.LLVMAddEarlyCSEPass(pm) -function removeDeadArgs!(mod::LLVM.Module) +function removeDeadArgs!(mod::LLVM.Module, tm) # We need to run globalopt first. This is because remove dead args will otherwise # take internal functions and replace their args with undef. 
Then on LLVM up to # and including 12 (but fixed 13+), Attributor will incorrectly change functions that @@ -1531,10 +1785,16 @@ function removeDeadArgs!(mod::LLVM.Module) end # Prevent dead-arg-elimination of functions which we may require args for in the derivative funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg=true) - func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"), EnumAttribute("nofree")]) - rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) - sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) - + if LLVM.version().major <= 15 + func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"), EnumAttribute("nofree")]) + rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) + sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) + else + func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")]) + rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) + sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) + end + for fn in functions(mod) if isempty(blocks(fn)) continue @@ -1561,12 +1821,17 @@ function removeDeadArgs!(mod::LLVM.Module) end end end + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for idx in (1, 2) if length(collect(parameters(fn))) < idx continue end attrs = collect(parameter_attributes(fn, idx)) - if 
any( ( kind(attr) == kind(EnumAttribute("sret")) || kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs) + if any( ( kind(attr) == sretkind || kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs) for u in LLVM.uses(fn) u = LLVM.user(u) if isa(u, LLVM.ConstantExpr) @@ -1602,7 +1867,7 @@ function removeDeadArgs!(mod::LLVM.Module) ModulePassManager() do pm instruction_combining!(pm) jl_inst_simplify!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? cse!(pm) run!(pm, mod) @@ -1621,7 +1886,7 @@ function removeDeadArgs!(mod::LLVM.Module) ModulePassManager() do pm instruction_combining!(pm) jl_inst_simplify!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? if RunAttributor[] if LLVM.version().major >= 13 @@ -1666,7 +1931,7 @@ function optimize!(mod::LLVM.Module, tm) add_library_info!(pm, triple(mod)) add_transform_info!(pm, tm) - propagate_julia_addrsp!(pm) + propagate_julia_addrsp_tm!(pm, tm) scoped_no_alias_aa!(pm) type_based_alias_analysis!(pm) basic_alias_analysis!(pm) @@ -1678,7 +1943,7 @@ end scalar_repl_aggregates_ssa!(pm) # SSA variant? 
mem_cpy_opt!(pm) always_inliner!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Extra gvn!(pm) # Extra instruction_combining!(pm) @@ -1693,22 +1958,28 @@ end jl_inst_simplify!(pm) reassociate!(pm) early_cse!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) loop_idiom!(pm) loop_rotate!(pm) - lower_simdloop!(pm) - licm!(pm) - if LLVM.version() >= v"15" - simple_loop_unswitch_legacy!(pm) + + if VERSION < v"1.11-" + lower_simdloop_tm!(pm, tm) + licm!(pm) + if LLVM.version() >= v"15" + simple_loop_unswitch_legacy!(pm) + else + loop_unswitch!(pm) + end else - loop_unswitch!(pm) + run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) end + instruction_combining!(pm) jl_inst_simplify!(pm) ind_var_simplify!(pm) loop_deletion!(pm) loop_unroll!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? 
gvn!(pm) @@ -1722,7 +1993,7 @@ end jl_inst_simplify!(pm) jump_threading!(pm) dead_store_elimination!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) cfgsimplification!(pm) loop_idiom!(pm) loop_deletion!(pm) @@ -1740,7 +2011,7 @@ end # GC passes barrier_noop!(pm) - gc_invariant_verifier!(pm, false) + gc_invariant_verifier_tm!(pm, tm, false) # FIXME: Currently crashes printing cfgsimplification!(pm) @@ -1750,7 +2021,7 @@ end gvn!(pm) # Exxtra run!(pm, mod) end - removeDeadArgs!(mod) + removeDeadArgs!(mod, tm) detect_writeonly!(mod) nodecayed_phis!(mod) end @@ -1762,12 +2033,12 @@ function addTargetPasses!(pm, tm, trip) end # https://github.com/JuliaLang/julia/blob/2eb5da0e25756c33d1845348836a0a92984861ac/src/aotcompile.cpp#L620 -function addOptimizationPasses!(pm) +function addOptimizationPasses!(pm, tm) add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) constant_merge!(pm) - propagate_julia_addrsp!(pm) + propagate_julia_addrsp_tm!(pm, tm) scoped_no_alias_aa!(pm) type_based_alias_analysis!(pm) basic_alias_analysis!(pm) @@ -1783,7 +2054,7 @@ function addOptimizationPasses!(pm) # merging the `alloca` for the unboxed data and the `alloca` created by the `alloc_opt` # pass. - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) # consider AggressiveInstCombinePass at optlevel > 2 instruction_combining!(pm) @@ -1801,24 +2072,46 @@ function addOptimizationPasses!(pm) # Load forwarding above can expose allocations that aren't actually used # remove those before optimizing loops. 
- alloc_opt!(pm) - loop_rotate!(pm) - # moving IndVarSimplify here prevented removing the loop in perf_sumcartesian(10:-1:1) - loop_idiom!(pm) - - # LoopRotate strips metadata from terminator, so run LowerSIMD afterwards - lower_simdloop!(pm) # Annotate loop marked with "loopinfo" as LLVM parallel loop - licm!(pm) - julia_licm!(pm) - # Subsequent passes not stripping metadata from terminator - instruction_combining!(pm) # TODO: createInstSimplifyLegacy - jl_inst_simplify!(pm) - ind_var_simplify!(pm) - loop_deletion!(pm) - loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll + alloc_opt_tm!(pm, tm) + + + if VERSION < v"1.11-" + loop_rotate!(pm) + # moving IndVarSimplify here prevented removing the loop in perf_sumcartesian(10:-1:1) + loop_idiom!(pm) + + # LoopRotate strips metadata from terminator, so run LowerSIMD afterwards + lower_simdloop_tm!(pm, tm) # Annotate loop marked with "loopinfo" as LLVM parallel loop + licm!(pm) + julia_licm_tm!(pm, tm) + # Subsequent passes not stripping metadata from terminator + instruction_combining!(pm) # TODO: createInstSimplifyLegacy + jl_inst_simplify!(pm) + + ind_var_simplify!(pm) + loop_deletion!(pm) + loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll + else + # LowerSIMDLoopPass + # LoopRotatePass [opt >= 2] + # LICMPass + # JuliaLICMPass + # SimpleLoopUnswitchPass + # LICMPass + # JuliaLICMPass + # IRCEPass + # LoopInstSimplifyPass + # - in ours this is instcombine with jlinstsimplify + # LoopIdiomRecognizePass + # IndVarSimplifyPass + # LoopDeletionPass + # LoopFullUnrollPass + run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) + end + # Run our own SROA on heap objects before LLVM's - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) # Re-run SROA after 
loop-unrolling (useful for small loops that operate, # over the structure of an aggregate) scalar_repl_aggregates!(pm) @@ -1840,7 +2133,7 @@ function addOptimizationPasses!(pm) # More dead allocation (store) deletion before loop optimization # consider removing this: - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) # see if all of the constant folding has exposed more loops # to simplification and deletion @@ -1859,31 +2152,31 @@ function addOptimizationPasses!(pm) aggressive_dce!(pm) end -function addMachinePasses!(pm) - combine_mul_add!(pm) +function addMachinePasses!(pm, tm) + combine_mul_add_tm!(pm, tm) # TODO: createDivRemPairs[] - demote_float16!(pm) + demote_float16_tm!(pm, tm) gvn!(pm) end -function addJuliaLegalizationPasses!(pm, lower_intrinsics=true) +function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) if lower_intrinsics # LowerPTLS removes an indirect call. As a result, it is likely to trigger # LLVM's devirtualization heuristics, which would result in the entire # pass pipeline being re-exectuted. Prevent this by inserting a barrier. barrier_noop!(pm) add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) - lower_exc_handlers!(pm) + lower_exc_handlers_tm!(pm, tm) # BUDE.jl demonstrates a bug here TODO - gc_invariant_verifier!(pm, false) + gc_invariant_verifier_tm!(pm, tm, false) verifier!(pm) # Needed **before** LateLowerGCFrame on LLVM < 12 # due to bug in `CreateAlignmentAssumption`. - remove_ni!(pm) - late_lower_gc_frame!(pm) - final_lower_gc!(pm) + remove_ni_tm!(pm, tm) + late_lower_gc_frame_tm!(pm, tm) + final_lower_gc_tm!(pm, tm) # We need these two passes and the instcombine below # after GC lowering to let LLVM do some constant propagation on the tags. # and remove some unnecessary write barrier checks. 
@@ -1891,20 +2184,20 @@ function addJuliaLegalizationPasses!(pm, lower_intrinsics=true) sccp!(pm) # Remove dead use of ptls dce!(pm) - lower_ptls!(pm, #=dump_native=# false) + lower_ptls_tm!(pm, tm, #=dump_native=# false) instruction_combining!(pm) jl_inst_simplify!(pm) # Clean up write barrier and ptls lowering cfgsimplification!(pm) else barrier_noop!(pm) - remove_ni!(pm) + remove_ni_tm!(pm, tm) end end function post_optimze!(mod, tm, machine=true) addr13NoAlias(mod) - removeDeadArgs!(mod) + removeDeadArgs!(mod, tm) for f in collect(functions(mod)) API.EnzymeFixupJuliaCallingConvention(f) end @@ -1914,15 +2207,15 @@ function post_optimze!(mod, tm, machine=true) end LLVM.ModulePassManager() do pm addTargetPasses!(pm, tm, LLVM.triple(mod)) - addOptimizationPasses!(pm) + addOptimizationPasses!(pm, tm) run!(pm, mod) end if machine # TODO enable validate_return_roots # validate_return_roots!(mod) LLVM.ModulePassManager() do pm - addJuliaLegalizationPasses!(pm, true) - addMachinePasses!(pm) + addJuliaLegalizationPasses!(pm, tm, true) + addMachinePasses!(pm, tm) run!(pm, mod) end end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 4b38256e61..b5bdb3afa2 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -1,3 +1,185 @@ +struct MemoryEffect + data::UInt32 +end + +@enum(ModRefInfo, + MRI_NoModRef = 0, + MRI_Ref = 1, + MRI_Mod = 2, + MRI_ModRef = 3) + +@enum(IRMemLocation, + ArgMem = 0, + InaccessibleMem = 1, + Other = 2) + +const BitsPerLoc = UInt32(2) +const LocMask = UInt32((1 << BitsPerLoc) - 1) +function getLocationPos(Loc::IRMemLocation) + return UInt32(Loc) * BitsPerLoc +end +function Base.:<<(mr::ModRefInfo, rhs::UInt32) + UInt32(mr) << rhs +end +function Base.:|(lhs::ModRefInfo, rhs::ModRefInfo) + ModRefInfo(UInt32(lhs) | UInt32(rhs)) +end +function Base.:&(lhs::ModRefInfo, rhs::ModRefInfo) + ModRefInfo(UInt32(lhs) & UInt32(rhs)) +end +const AllEffects = MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_ModRef << 
getLocationPos(InaccessibleMem)) | (MRI_ModRef << getLocationPos(Other))) +const ReadOnlyEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_Ref << getLocationPos(Other))) +const ReadOnlyArgMemEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) +const NoEffects = MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) + +# Get ModRefInfo for any location. +function getModRef(effect::MemoryEffect, loc::IRMemLocation)::ModRefInfo + ModRefInfo((effect.data >> getLocationPos(loc)) & LocMask) +end + +function getModRef(effect::MemoryEffect)::ModRefInfo + cur = MRI_NoModRef + for loc in (ArgMem, InaccessibleMem, Other) + cur |= getModRef(effect, loc) + end + return cur +end + +function setModRef(effect::MemoryEffect, Loc::IRMemLocation, MR::ModRefInfo)::MemoryEffect + data = effect.data + data &= ~(LocMask << getLocationPos(Loc)) # clear the bits currently stored for Loc + data |= MR << getLocationPos(Loc) # then store MR in that slot + return MemoryEffect(data) +end + +function setModRef(effect::MemoryEffect)::MemoryEffect + for loc in (ArgMem, InaccessibleMem, Other) + effect = setModRef(effect, mri)= getModRef(effect, loc) # NOTE(review): mri is undefined and this parses as a local method definition; intended body unclear, confirm + end + return effect +end + +function set_readonly(mri::ModRefInfo) + return mri & MRI_Ref +end +function set_writeonly(mri::ModRefInfo) + return mri & MRI_Mod +end +function set_reading(mri::ModRefInfo) + return mri | MRI_Ref +end +function set_writing(mri::ModRefInfo) + return mri | MRI_Mod +end + +function set_readonly(effect::MemoryEffect) + data = UInt32(0) + for loc in (ArgMem, InaccessibleMem, Other) + data |= UInt32(set_readonly(getModRef(effect, loc))) << getLocationPos(loc) # accumulate every location, not just the last one + end + return MemoryEffect(data) +end + +function is_readonly(mri::ModRefInfo) + return mri == MRI_NoModRef || mri == MRI_Ref +end + +function is_readnone(mri::ModRefInfo) + return mri 
== MRI_NoModRef +end + +function is_writeonly(mri::ModRefInfo) + return mri == MRI_NoModRef || mri == MRI_Mod +end + +for n in (:is_readonly, :is_readnone, :is_writeonly) +@eval begin + function $n(memeffect::MemoryEffect) + return $n(getModRef(memeffect)) + end +end +end + +function is_readonly(f::LLVM.Function) + for attr in collect(function_attributes(f)) + if kind(attr) == kind(EnumAttribute("readonly")) + return true + end + if kind(attr) == kind(EnumAttribute("readnone")) + return true + end + if LLVM.version().major > 15 + if kind(attr) == kind(EnumAttribute("memory")) + if is_readonly(MemoryEffect(value(attr))) + return true + end + end + end + end + return false +end + +function is_readnone(f::LLVM.Function) + for attr in collect(function_attributes(f)) # inspect the function passed in + if kind(attr) == kind(EnumAttribute("readnone")) + return true + end + if LLVM.version().major > 15 + if kind(attr) == kind(EnumAttribute("memory")) + if is_readnone(MemoryEffect(value(attr))) + return true + end + end + end + end + return false +end + +function is_writeonly(f::LLVM.Function) + for attr in collect(function_attributes(f)) # inspect the function passed in + if kind(attr) == kind(EnumAttribute("readnone")) + return true + end + if kind(attr) == kind(EnumAttribute("writeonly")) + return true + end + if LLVM.version().major > 15 + if kind(attr) == kind(EnumAttribute("memory")) + if is_writeonly(MemoryEffect(value(attr))) + return true + end + end + end + end + return false +end + +function set_readonly!(fn::LLVM.Function) + attrs = collect(function_attributes(fn)) + if LLVM.version().major <= 15 + if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs) + delete!(function_attributes(fn), EnumAttribute("writeonly")) + push!(function_attributes(fn), EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("readonly")) + end + return true + 
end + return false + else + for attr in attrs + if kind(attr) == kind(EnumAttribute("memory")) + old = MemoryEffect(value(attr)) + eff = set_readonly(old) + push!(function_attributes(fn), EnumAttribute("memory", eff.data)) + return old != eff + end + end + push!(function_attributes(fn), EnumAttribute("memory", set_readonly(AllEffects).data)) + return true + end +end function get_function!(mod::LLVM.Module, name::AbstractString, FT::LLVM.FunctionType, attrs=[]) if haskey(functions(mod), name) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 8715ce1991..951d527d6d 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -385,27 +385,90 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) nval = ptrtoint!(b, call!(b, LLVM.function_type(mfn2), mfn2, [LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0))]), value_type(inst)) replace_uses!(inst, nval) LLVM.API.LLVMInstructionEraseFromParent(inst) - elseif fn == "jl_load_and_lookup" + elseif fn == "jl_load_and_lookup" || fn == "ijl_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) - legal, flib = abs_cstring(operands(inst)[1]) - legal2, fname = abs_cstring(operands(inst)[2]) - legal &= legal2 + arg1 = operands(inst)[1] - hnd = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 2)) - if isa(hnd, LLVM.GlobalVariable) - hnd = LLVM.name(hnd) - else - legal = false + while isa(arg1, ConstantExpr) + if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || opcode(arg1) == LLVM.API.LLVMBitCast || opcode(arg1) == LLVM.API.LLVMIntToPtr + arg1 = operands(arg1)[1] + else + break + end end + if isa(arg1, LLVM.ConstantInt) + arg1 = reinterpret(Ptr{Cvoid}, convert(UInt, arg1)) + legal2, fname = abs_cstring(operands(inst)[2]) + if legal2 + hnd = operands(inst)[3] + if isa(hnd, LLVM.GlobalVariable) + hnd = LLVM.name(hnd) + if fn == "jl_load_and_lookup" # this branch only sees jl_/ijl_load_and_lookup, never the lazy variant + res = ccall(:jl_load_and_lookup, Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, 
reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) + else + res = ccall(:ijl_load_and_lookup, Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) + end + replaceWith = LLVM.ConstantInt(LLVM.IntType(8*sizeof(Int)), reinterpret(UInt, res)) + for u in LLVM.uses(inst) + st = LLVM.user(u) + if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) + for u in LLVM.uses(ptr) + ld = LLVM.user(u) + if isa(ld, LLVM.LoadInst) + b = IRBuilder() + position!(b, ld) + for u in LLVM.uses(ld) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + end + replace_uses!(ld, LLVM.inttoptr!(b, replaceWith, value_type(inst))) + end + end + end + end - if !legal - return + b = IRBuilder() + position!(b, inst) + replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.PHIInst) + if all(x->first(x) == inst || first(x) == replacement, LLVM.incoming(u)) + + for u in LLVM.uses(u) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.BitCastInst) + for u1 in LLVM.uses(u) + u1 = LLVM.user(u1) + if isa(u1, LLVM.CallInst) + push!(calls, u1) + end + end + replace_uses!(u, LLVM.inttoptr!(b, replaceWith, value_type(u))) + end + end + end + end + end + replace_uses!(inst, replacement) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end end - # res = ccall(:jl_load_and_lookup, Ptr{Cvoid}, (Cstring, Cstring, Ptr{Cvoid}), flib, fname, cglobal(Symbol(hnd))) - push!(errors, ("jl_load_and_lookup", bt, nothing)) + + elseif fn == "jl_lazy_load_and_lookup" || fn == "ijl_lazy_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 37208f4afc..d0146a64e2 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ 
-473,6 +473,54 @@ function arrayreshape_rev(B, orig, gutils, tape) return nothing end +function gcloaded_fwd(B, orig, gutils, normalR, shadowR) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) + return true + end + + origops = LLVM.operands(orig) + if is_constant_value(gutils, origops[1]) + emit_error(B, orig, "Enzyme: gcloaded has active return, but inactive input(1)") + end + if is_constant_value(gutils, origops[2]) + emit_error(B, orig, "Enzyme: gcloaded has active return, but inactive input(2)") + end + + width = get_width(gutils) + + shadowin1 = invert_pointer(gutils, origops[1], B) + shadowin2 = invert_pointer(gutils, origops[2], B) + if width == 1 + args = LLVM.Value[shadowin1, shadowin2] + shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx in 1:width + args = LLVM.Value[ + extract_value!(B, shadowin1, idx-1) + extract_value!(B, shadowin2, idx-1) + ] + tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) + shadowres = insert_value!(B, shadowres, tmp, idx-1) + end + end + unsafe_store!(shadowR, shadowres.ref) + + return false +end + +function gcloaded_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + gcloaded_fwd(B, orig, gutils, normalR, shadowR) +end + +function gcloaded_rev(B, orig, gutils, tape) + return nothing +end + function boxfloat_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) @@ -1206,6 +1254,12 @@ end @revfunc(jlcall2_rev), @fwdfunc(jlcall2_fwd), ) + register_handler!( + ("julia.gc_loaded",), + @augfunc(gcloaded_augfwd), + @revfunc(gcloaded_rev), + @fwdfunc(gcloaded_fwd), 
+ ) register_handler!( ("jl_apply_generic", "ijl_apply_generic"), @augfunc(generic_augfwd), diff --git a/src/typetree.jl b/src/typetree.jl index 79ca41cd81..2c846ae49e 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -152,28 +152,46 @@ function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, return tt end -function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} - offset = 0 - - tt = copy(typetree(T, ctx, dl, seen)) - if !allocatedinline(T) +@static if VERSION < v"1.11-" + function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} + offset = 0 + + tt = copy(typetree(T, ctx, dl, seen)) + if !allocatedinline(T) + merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, 0) + end merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, 0) - end - merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, offset) + only!(tt, offset) + + offset += sizeof(Ptr{Cvoid}) - offset += sizeof(Ptr{Cvoid}) + sizeofstruct = offset + 2 + 2 + 4 + 2 * sizeof(Csize_t) + if true # STORE_ARRAY_LEN + sizeofstruct += sizeof(Csize_t) + end - sizeofstruct = offset + 2 + 2 + 4 + 2 * sizeof(Csize_t) - if true # STORE_ARRAY_LEN - sizeofstruct += sizeof(Csize_t) + for i in offset:(sizeofstruct-1) + merge!(tt, TypeTree(API.DT_Integer, i, ctx)) + end + return tt end +else + function typetree_inner(::Type{<:GenericMemory{kind, T}}, ctx, dl, seen::TypeTreeTable) where {kind, T} + offset = 0 + tt = copy(typetree(T, ctx, dl, seen)) + if !allocatedinline(T) + merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, 0) + end + merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, sizeof(Csize_t)) - for i in offset:(sizeofstruct-1) - merge!(tt, TypeTree(API.DT_Integer, i, ctx)) + for i in 0:(sizeof(Csize_t)-1) + merge!(tt, TypeTree(API.DT_Integer, i, ctx)) + end + return tt end - return tt end if VERSION >= v"1.7.0-DEV.204" From eb79121ddc5dc94aa66a4c51767041d176e860bc Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 17 Jun 2024 21:38:49 
-0400 Subject: [PATCH 123/495] Fix kw rrule closure argument index (#1543) --- src/rules/customrules.jl | 10 +++++----- test/kwrrules.jl | 42 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 1eab07c3f2..9f0fcea5de 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -734,7 +734,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, end end - # push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) + push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) needsTape = !isghostty(TapeT) && !Core.Compiler.isconstType(TapeT) @@ -765,11 +765,11 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, funcTy = rev_TT.parameters[isKWCall ? 4 : 2] if needsTape @assert tape != C_NULL - tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) + !isghostty(funcTy) - trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself+(RT <: Active) + tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy) + trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself + (RT <: Active) innerTy = value_type(parameters(llvmf)[trueidx]) if innerTy != value_type(tape) - if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL + if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL || TapeT == Array msg = sprint() do io println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") println(io, "tape_idx=", tape_idx) @@ -831,7 +831,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if any_jltypes(llty) emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al) + insert!(args, 
1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)), al) end end diff --git a/test/kwrrules.jl b/test/kwrrules.jl index a62ba94608..f5b9d2338a 100644 --- a/test/kwrrules.jl +++ b/test/kwrrules.jl @@ -111,5 +111,47 @@ g4(x, y) = f_kw4(x; y) @test autodiff(Reverse, g4, Active(2.0), Const(42.0))[1][1] ≈ 42004.0 @test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, g4, Active(2.0), Active(42.0))[1] +struct Closure2 + v::Vector{Float64} + str::String +end + +function (cl::Closure2)(x; width=7) + val = cl.v[1] * x * width + cl.v[1] = 0.0 + return val +end + +function wrapclos(cl, x) + cl(x; width=9) +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure2}, + ::Type{<:Active}, args::Vararg{Active,N}; width=7) where {N} + vec = copy(func.val.v) + pval = func.val(args[1].val) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure2}, + dret::Active, tape, args::Vararg{Active,N}; width=7) where {N} + dargs = ntuple(Val(N)) do i + 7 * args[1].val * dret.val + tape[1] * 1000 + width * 100000 + end + return dargs +end + +@testset "KWClosure rule" begin + cl = Closure2([3.14], "3.14") + res = autodiff(Reverse, wrapclos, Active, Const(cl), Active(2.7))[1][2] + @test res ≈ 7 * 2.7 + 3.14 * 1000 + 9 * 100000 + @test cl.v[1] ≈ 0.0 +end + end # KWReverseRules From a93b618be676c51fd5316794d10961adccd5fdcd Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 18 Jun 2024 11:20:43 -0400 Subject: [PATCH 124/495] Add more blas fns (#1544) * Add more blas fns * fix * Update Project.toml --- Project.toml | 20 ++++++++++---------- src/compiler.jl | 5 ++++- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index f72e22bd39..9441aa6ef1 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,16 @@ Preferences = 
"21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" @@ -29,17 +39,7 @@ SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.6" -[extensions] -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/src/compiler.jl b/src/compiler.jl index 3f57743cc0..830232d4ab 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5154,10 +5154,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; check_ir(job, mod) disableFallback = String[] + + ForwardModeDerivatives = ("dot","gemm","gemv","axpy","copy","scal") + ReverseModeDerivatives = ("dot","gemm","gemv","axpy","copy","scal", "trmv", "syrk", "trmm", "trsm") # Tablegen BLAS does not support forward mode yet if !(mode == API.DEM_ForwardMode && Enzyme.API.runtimeActivity()) for ty in ("s", "d") - for func in ("dot","gemm","gemv","axpy","copy","scal") + for func in (mode == API.DEM_ForwardMode ? 
ForwardModeDerivatives : ReverseModeDerivatives) for prefix in ("", "cblas_") for ending in ("", "_", "64_", "_64_") push!(disableFallback, prefix*ty*func*ending) From 5cfe97bef252cb10552c2908c6fa17a3100789a9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 18 Jun 2024 16:12:47 -0400 Subject: [PATCH 125/495] Re-enable ci for amd math fns (#1537) * Re-enable ci for amd math fns * better amd gpu errs * print * amd intrs * Update Project.toml --- Project.toml | 2 +- src/compiler.jl | 42 +++++++++++++++++++++++++++++++++++++++++- test/amdgpu.jl | 38 ++++++++++++++++++++------------------ 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 9441aa6ef1..3f273b5d12 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.122" +Enzyme_jll = "0.0.123" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/src/compiler.jl b/src/compiler.jl index 830232d4ab..18be62e92a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1632,7 +1632,7 @@ function emit_error(B::LLVM.IRBuilder, orig, string) string*=sprint(io->Base.show_backtrace(io, bt)) end - ct = if occursin("ptx", LLVM.triple(mod)) + ct = if occursin("ptx", LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) GPUCompiler.emit_exception!(B, string, orig) else call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) @@ -5932,6 +5932,46 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end end + if parent_job.config.target isa GPUCompiler.GCNCompilerTarget + arg1 = ("acos", "acosh", "asin", + "asinh", "atan2", "atan", + "atanh", "cbrt", "ceil", + "copysign", "cos", "native_cos", + "cosh", "cospi", "i0", + "i1", "erfc", "erfcinv", + "erfcx", "erf", "erfinv", + "exp10", "native_exp10", "exp2", + "exp", "native_exp", "expm1", + "fabs", "fdim", "floor", + "fma", 
"fmax", "fmin", + "fmod", "frexp", "hypot", + "ilogb", "isfinite", "isinf", + "isnan", "j0", "j1", + "ldexp", "lgamma", "log10", + "native_log10", "log1p", "log2", + "log2", "logb", "log", + "native_log", "modf", "nearbyint", + "nextafter", "len3", "len4", + "ncdf", "ncdfinv", "pow", + "pown", "rcbrt", "remainder", + "remquo", "rhypot", "rint", + "rlen3", "rlen4", "round", + "rsqrt", "scalb", "scalbn", + "signbit", "sincos", "sincospi", + "sin", "native_sin", "sinh", + "sinpi", "sqrt", "native_sqrt", + "tan", "tanh", "tgamma", + "trunc", "y0", "y1") + for n in arg1, (T, pf, lpf) in ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) + fname = "__ocml_"*n*"_"*lpf + if !haskey(functions(mod), fname) + FT = LLVM.FunctionType(T, [T], vararg=false) + wrapper_f = LLVM.Function(mod, fname, FT) + llname = "llvm."*n*"."*lpf + push!(function_attributes(wrapper_f), StringAttribute("implements", llname)) + end + end + end end for (name, fnty) in fnsToInject for (T, JT, pf) in ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) diff --git a/test/amdgpu.jl b/test/amdgpu.jl index da826efcc4..09d120e246 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -38,15 +38,17 @@ function grad_exp_kernel(A, dA) return nothing end -# @testset "exp_kernel" begin -# A = AMDGPU.ones(64,) -# @roc groupsize=length(A) exp_kernel(A) -# A = AMDGPU.ones(64,) -# dA = similar(A) -# dA .= 1 -# @roc groupsize=length(A) grad_exp_kernel(A, dA) -# @test all(dA .== exp(1.f0)) -# end +Enzyme.API.printall!(true) + +@testset "exp_kernel" begin + A = AMDGPU.ones(64,) + @roc groupsize=length(A) exp_kernel(A) + A = AMDGPU.ones(64,) + dA = similar(A) + dA .= 1 + @roc groupsize=length(A) grad_exp_kernel(A, dA) + @test all(dA .== exp(1.f0)) +end function cos_kernel(A) i = workitemIdx().x @@ -61,12 +63,12 @@ function grad_cos_kernel(A, dA) return nothing end -# @testset "cos_kernel" begin -# A = AMDGPU.ones(64,) -# @roc groupsize=length(A) cos_kernel(A) -# A = AMDGPU.ones(64,) -# dA = 
similar(A) -# dA .= 1 -# @roc groupsize=length(A) grad_cos_kernel(A, dA) -# @test all(dA .≈ -sin(1.f0)) -# end +@testset "cos_kernel" begin + A = AMDGPU.ones(64,) + @roc groupsize=length(A) cos_kernel(A) + A = AMDGPU.ones(64,) + dA = similar(A) + dA .= 1 + @roc groupsize=length(A) grad_cos_kernel(A, dA) + @test all(dA .≈ -sin(1.f0)) +end From c14685c0273ab0e33bccf7c599f1e787601e9d03 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 18 Jun 2024 22:32:25 -0400 Subject: [PATCH 126/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3f273b5d12..864c5c7ee3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.14" +version = "0.12.15" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From d10c708496815b854f485afd9440cca60d01331f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 21 Jun 2024 09:48:09 -0400 Subject: [PATCH 127/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 864c5c7ee3..2e7083cbb0 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.123" +Enzyme_jll = "0.0.124" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From cfc1470edbd29d0b332bd7617d393ed4c86a25d7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 21 Jun 2024 09:48:24 -0400 Subject: [PATCH 128/495] Don't bother rewriting known calls (#1550) --- src/compiler/validation.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 951d527d6d..1a3ef9f78b 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -947,6 +947,16 
@@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width push!(todo, (cur[off[1]], off[2:end])) continue end + + if isa(cur, LLVM.CallInst) + dest = called_operand(cur) + if isa(dest, LLVM.Function) + fn = LLVM.name(dest) + if fn == "julia.call" || fn == "julia.call2" + continue + end + end + end msg = sprint() do io::IO println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") From 6b01be14cef3e8a80a518c7ae4991ecc5f813f60 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 21 Jun 2024 11:45:43 -0400 Subject: [PATCH 129/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2e7083cbb0..33d4e88848 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.15" +version = "0.12.16" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 5a9c3e7c80f36874e7c108f81eb29c0941c96b83 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 21 Jun 2024 18:57:54 -0400 Subject: [PATCH 130/495] Type unstable rules (#1553) * Type unstable rules * fixup --- src/absint.jl | 8 +++++++- src/rules/typeunstablerules.jl | 5 +++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 10dc024013..b3d009c4f0 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -13,11 +13,16 @@ function absint(arg::LLVM.Value, partial::Bool=false) ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), ("jl_box_int32", Int32), ("ijl_box_int32", Int32), ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), + ("jl_box_char", Char), ("ijl_box_char", Char), ) if nm == fname v = first(operands(arg)) if isa(v, ConstantInt) - return (true, convert(ty, v)) + if ty == Char + return (true, Char(convert(Int, v))) + else + return (true, convert(ty, v)) + end end end end @@ -144,6 +149,7 @@ function 
abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ ("jl_box_int32", Int32), ("ijl_box_int32", Int32), ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), ("jl_box_float32", Float32), ("ijl_box_float32", Float32), + ("jl_box_char", Char), ("ijl_box_char", Char), ) if nm == fname return (true, ty) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 3ed5cfd72f..6ba60d3b3e 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -298,9 +298,10 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) end if !newstruct_common(#=fwd=#true, #=run=#true, offset, B, orig, gutils, normalR, shadowR) - abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] origops = collect(operands(orig)) - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs)*" "*string([v for v in origops[offset+1:end-1]])) + abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] + emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs_partial)*" "*string([v for v in origops[offset+1:end-1]])) end return false From 31bfad6fdfeca0e4e8e040ccb4c868a9c87aed6f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 21 Jun 2024 19:43:14 -0400 Subject: [PATCH 131/495] Update AdaptExt.jl (#1551) * Update AdaptExt.jl cc @vchuravy * Update EnzymeCore.jl --- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/ext/AdaptExt.jl | 6 ++++++ lib/EnzymeCore/src/EnzymeCore.jl | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index fe809f1a14..57bba3fd71 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = 
"f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.5" +version = "0.7.6" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/ext/AdaptExt.jl b/lib/EnzymeCore/ext/AdaptExt.jl index b2234404d3..4d62a20675 100644 --- a/lib/EnzymeCore/ext/AdaptExt.jl +++ b/lib/EnzymeCore/ext/AdaptExt.jl @@ -15,5 +15,11 @@ end function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) end +function Adapt.adapt_structure(to, x::MixedDuplicated) + return MixedDuplicated(adapt(to, x.val), adapt(to, x.dval)) +end +function Adapt.adapt_structure(to, x::BatchMixedDuplicated) + return BatchMixedDuplicated(adapt(to, x.val), adapt(to, x.dval)) +end end #module diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index fee15c9dc6..d82072724f 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -3,6 +3,7 @@ module EnzymeCore export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed +export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI export BatchDuplicatedFunc From 3f1be4fb3750855260baba5db5fb3b782f0e5b70 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 22 Jun 2024 10:33:35 -0400 Subject: [PATCH 132/495] Fix custom rule overwriting error to be runtime not compile time (#1552) * nicer custom rule error * more error messages --- src/rules/customrules.jl | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 9f0fcea5de..6543374db5 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -134,13 +134,24 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, 
gutils::GradientUtils, ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) if value_type(val) != eltype(value_type(ptr)) - @assert !overwritten[end] - val = load!(B, arty, val) + if overwritten[end] + emit_error(B, orig, "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + end + if arty != eltype(value_type(val)) + val = load!(B, arty, val) + else + val = LLVM.UndefValue(arty) + emit_error(B, orig, "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + end end - store!(B, val, ptr) - if any_jltypes(llty) - emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + if eltype(value_type(ptr)) == value_type(val) + store!(B, val, ptr) + if any_jltypes(llty) + emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + end + else + emit_error(B, orig, "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") end push!(args, al) From 74f27889a23dd5132e31e990bdbc54ace3d2799c Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sat, 22 Jun 2024 12:06:05 -0400 Subject: [PATCH 133/495] Embarassing rules bugfix --- src/rules/customrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 6543374db5..79d6b96f03 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -137,7 +137,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, if overwritten[end] emit_error(B, orig, "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") end - if arty != eltype(value_type(val)) + if arty == eltype(value_type(val)) val = load!(B, arty, val) else val = LLVM.UndefValue(arty) From 9ad264a328639c1dad055ed090ebb8b541aa6b4a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Jun 2024 07:47:10 -0400 Subject: [PATCH 134/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 33d4e88848..c2ecdcf9df 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.16" +version = "0.12.17" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.124" +Enzyme_jll = "0.0.125" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From 0217565007c22064a6e9c733c99570b4b2b7c19a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Jun 2024 12:34:27 -0400 Subject: [PATCH 135/495] Inactive get total bytes (#1557) --- src/compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 18be62e92a..81f413f5c1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,7 @@ Dict{DataType, 
Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", "ijl_array_grow_at", "jl_array_grow_at", "ijl_try_substrtod", "jl_try_substrtod", "jl_f__apply_iterate", @@ -180,6 +181,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", "ijl_try_substrtod", "jl_try_substrtod", "ijl_tagged_gensym", "jl_tagged_gensym", "jl_get_world_counter", "ijl_get_world_counter", From 157d3dbd8f983b916a5ba3d67739ba4240dfed9d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Jun 2024 17:48:20 -0400 Subject: [PATCH 136/495] Within mi error message improvement (#1556) --- src/compiler.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 81f413f5c1..53e753eb37 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2010,6 +2010,11 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err Base.show_backtrace(io, bt) println(io) end + pscope = parent_scope(val) + mi, rt = enzyme_custom_extract_mi(pscope, #=error=#false) + if mi !== nothing + println(io, "within ", mi) + end end emit_error(B, nothing, msg2) return C_NULL From ce4919b0a7201a8442fc56258873bfce29444a11 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Jun 2024 22:07:49 -0400 Subject: [PATCH 137/495] Support add into vec (#1559) --- src/rules/jitrules.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index dc0bd04d80..763e815d49 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -915,7 +915,19 @@ end return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) end +function add_into_vec!(val::Base.RefValue, expr, vec, idx_in_vec) + val[] = recursive_add(val[], expr, identity, guaranteed_nonactive) + nothing +end +function add_into_vec!(val::T, expr, 
vec, idx_in_vec) where T + if ismutable(vec) + @inbounds vec[idx_in_vec] = recursive_add(val, expr, identity, guaranteed_nonactive) + else + error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") + end + nothing +end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, tt′, DF, Nargs} @@ -1049,13 +1061,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween end) else val = @inbounds vec[idx_in_vec] - if val isa Base.RefValue - val[] = recursive_add(val[], expr) - elseif ismutable(vec) - @inbounds vec[idx_in_vec] = recursive_add(val, expr, identity, guaranteed_nonactive) - else - error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") - end + add_into_vec!(Base.inferencebarrier(val), expr, vec, idx_in_vec) end end From 72a2b066a6c15a42bb44f58fdc10e4aecae9b6b3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Jun 2024 22:07:59 -0400 Subject: [PATCH 138/495] Inactive apply type (#1555) * Inactive apply type * fix * more type rules * mark lookup worlds as inactive * fix * specinfo --- src/absint.jl | 2 ++ src/compiler.jl | 4 ++++ src/compiler/validation.jl | 20 +++++++------------- src/internal_rules.jl | 3 +++ 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index b3d009c4f0..90e7a35427 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -150,6 +150,8 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), ("jl_box_float32", Float32), ("ijl_box_float32", Float32), ("jl_box_char", Char), ("ijl_box_char", 
Char), + ("jl_specializations_get_linfo", Core.MethodInstance), + ("ijl_specializations_get_linfo", Core.MethodInstance), ) if nm == fname return (true, ty) diff --git a/src/compiler.jl b/src/compiler.jl index 53e753eb37..1c6646b847 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,8 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_specializations_get_linfo", "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", "ijl_array_grow_at", "jl_array_grow_at", "ijl_try_substrtod", "jl_try_substrtod", @@ -181,6 +183,8 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_specializations_get_linfo", "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", "ijl_try_substrtod", "jl_try_substrtod", "ijl_tagged_gensym", "jl_tagged_gensym", diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 1a3ef9f78b..3649188320 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -642,25 +642,19 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) legal, iterlib = absint(operands(inst)[iteroff+1]) if legal && iterlib == Base.iterate legal, GT = abs_typeof(operands(inst)[4+1], true) - if legal && GT <: Vector - funcoff = 3 - funclib = operands(inst)[funcoff+1] - while isa(funclib, LLVM.ConstantExpr) - funclib = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(funclib, 0)) - end - if isa(funclib, ConstantInt) - rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, funclib)) - funclib = Base.unsafe_pointer_to_objref(rep) - tys = [typeof(funclib), Vararg{Any}] - if is_inactive(tys, world, method_table) + funcoff = 3 + legal2, funclib = abs_typeof(operands(inst)[funcoff+1]) + if legal && (GT <: Vector || GT <: Tuple) + if legal2 + tys = [funclib, Vararg{Any}] 
+ if funclib == typeof(Core.apply_type) || is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) - end - if funclib == Base.tuple && length(operands(inst)) == 4+1+1 && Base.isconcretetype(GT) && Enzyme.Compiler.guaranteed_const_nongen(GT, world) + elseif funclib == typeof(Base.tuple) && length(operands(inst)) == 4+1+1 && Base.isconcretetype(GT) && Enzyme.Compiler.guaranteed_const_nongen(GT, world) inactive = LLVM.StringAttribute("enzyme_inactive", "") LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) nofree = LLVM.EnumAttribute("nofree") diff --git a/src/internal_rules.jl b/src/internal_rules.jl index fe5ed05b89..58a312063f 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -118,6 +118,9 @@ end @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:DataType} = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:Module} = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true +@inline EnzymeRules.inactive_type(v::Type{Core.MethodMatch}) = true +@inline EnzymeRules.inactive_type(v::Type{Core.Compiler.WorldRange}) = true +@inline EnzymeRules.inactive_type(v::Type{Core.MethodInstance}) = true @inline width(::Duplicated) = 1 @inline width(::BatchDuplicated{T, N}) where {T, N} = N From f9b63ad485a9e8fc8b62a678e503818bcb3c954d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Jun 2024 23:41:32 
-0400 Subject: [PATCH 139/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index c2ecdcf9df..dea3bee183 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.17" +version = "0.12.18" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.125" +Enzyme_jll = "0.0.126" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From a7153dfb397cc86d143d4fd894b41628b00c6a0f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 24 Jun 2024 09:20:05 -0400 Subject: [PATCH 140/495] Ensure correct return diffe type (#1565) --- src/rules/customrules.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 79d6b96f03..1d9e568eff 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -281,8 +281,12 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, # is used as part of differential use analysis, we need to avoid an ininite recursion. Thus use # the version without differential use if actual unreachable results are not available anyways. 
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) + cmode = mode + if cmode == API.DEM_ReverseModeGradient + cmode = API.DEM_ReverseModePrimal + end activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 - API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, cmode) else actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) if !isghostty(RealRt) @@ -305,7 +309,6 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, if !needsPrimal && activep == API.DFT_DUP_ARG activep = API.DFT_DUP_NONEED end - if activep == API.DFT_CONSTANT RT = Const{RealRt} From 72809de4dcf61356b929774d4aff7dcbb9d26be7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 24 Jun 2024 09:20:14 -0400 Subject: [PATCH 141/495] More erros on allocated inline mismatch (#1566) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1c6646b847..ecb4b7ae66 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3833,7 +3833,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated @assert !sret_union if allocatedinline(actualRetType) != allocatedinline(literal_rt) - throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) + throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype), sret_union=$(sret_union), pactualRetType=$(pactualRetType)")) end if rettype <: Active if !allocatedinline(actualRetType) From 
10b879ea0117eb091d5c613df143979590ac5f2d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 25 Jun 2024 00:42:45 -0400 Subject: [PATCH 142/495] fixup tuple any (#1567) --- src/compiler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index ecb4b7ae66..522020c837 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3655,7 +3655,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType - retTT = typetree(retT, ctx, dl, seen) + retTT = (actualRetType <: Tuple && in(Any, actualRetType.parameters)) ? TypeTree() : typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -4579,6 +4579,8 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, rt = Nothing elseif Base.isstructtype(jlrettype) && Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) rt = Nothing + elseif jlrettype <: Tuple && in(Any, jlrettype.parameters) + rt = Any elseif jlrettype isa Union nbytes = 0 allunbox = for_each_uniontype_small(jlrettype) do jlrettype From f7ef35364f818539a494e6381100e6ac921ff339 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 25 Jun 2024 00:43:06 -0400 Subject: [PATCH 143/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dea3bee183..78459ef3bc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.18" +version = "0.12.19" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From d473f11a0abe5f98f5a5305c0c61117e1188b954 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 25 Jun 2024 14:06:13 -0400 Subject: [PATCH 144/495] More info of alloc (#1568) --- src/compiler.jl | 
6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 522020c837..68b8fdc1d6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3833,7 +3833,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated @assert !sret_union if allocatedinline(actualRetType) != allocatedinline(literal_rt) - throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype), sret_union=$(sret_union), pactualRetType=$(pactualRetType)")) + msg = sprint() do io + println(io, string(enzymefn)) + println(io, "Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype), sret_union=$(sret_union), pactualRetType=$(pactualRetType)") + end + throw(AssertionError(msg)) end if rettype <: Active if !allocatedinline(actualRetType) From d4f44d6e7f1615fadb158f4ad93632dd413e7a87 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 25 Jun 2024 14:48:16 -0400 Subject: [PATCH 145/495] uparm (#1570) --- Project.toml | 2 +- src/compiler.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 78459ef3bc..d7fb1d827f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.19" +version = "0.12.20" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/compiler.jl b/src/compiler.jl index 68b8fdc1d6..36ec31b60b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4583,8 +4583,6 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, rt = Nothing elseif Base.isstructtype(jlrettype) && 
Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) rt = Nothing - elseif jlrettype <: Tuple && in(Any, jlrettype.parameters) - rt = Any elseif jlrettype isa Union nbytes = 0 allunbox = for_each_uniontype_small(jlrettype) do jlrettype @@ -4601,6 +4599,8 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, else rt = Any end + elseif jlrettype <: Tuple && in(Any, jlrettype.parameters) + rt = Any elseif !GPUCompiler.deserves_retbox(jlrettype) lRT = convert(LLVMType, jlrettype ) if !isa(lRT, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, lRT) From c4068fc96a9b574fd08593178611b153cb71ac5b Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 25 Jun 2024 23:04:06 -0400 Subject: [PATCH 146/495] More union type check fixes --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 36ec31b60b..9a7f0940a5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3655,7 +3655,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType - retTT = (actualRetType <: Tuple && in(Any, actualRetType.parameters)) ? TypeTree() : typetree(retT, ctx, dl, seen) + retTT = (!isa(actualRetType, Union) && actualRetType <: Tuple && in(Any, actualRetType.parameters)) ? 
TypeTree() : typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) From cdb4df3420f0425f54aa31326ca1d6daf48cb5da Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 27 Jun 2024 02:18:49 +0200 Subject: [PATCH 147/495] Fix getfield of const (#1572) * Fix getfield of const * fix * add test * fixup --- src/compiler.jl | 4 +++- src/rules/typeunstablerules.jl | 34 +++++++++++++++++++++++++++------- test/runtests.jl | 28 ++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 9a7f0940a5..94ff0bca18 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_field_index", "jl_field_index", "ijl_specializations_get_linfo", "jl_specializations_get_linfo", "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", @@ -183,6 +184,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_field_index", "jl_field_index", "ijl_specializations_get_linfo", "jl_specializations_get_linfo", "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", @@ -3258,7 +3260,7 @@ function annotate!(mod, mode) end end - for fname in ("jl_excstack_state","ijl_excstack_state") + for fname in ("jl_excstack_state","ijl_excstack_state", "ijl_field_index", "jl_field_index") if haskey(fns, fname) fn = fns[fname] if LLVM.version().major <= 15 diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 6ba60d3b3e..a09968a325 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -606,7 +606,7 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco RT = Core.Typeof(res) actreg = active_reg_nothrow(RT, Val(nothing)) - if actreg == ActiveState + if actreg == 
ActiveState || (isconst && actreg == MixedState) if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else @@ -626,6 +626,16 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco end)...)) return fval end + elseif isconst + if length(dptrs) == 0 + return make_zero(res) + else + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + make_zero(res) + end)...)) + return fval + end else if length(dptrs) == 0 return res @@ -648,7 +658,7 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc end RT = Core.Typeof(res) actreg = active_reg_nothrow(RT, Val(nothing)) - if actreg == ActiveState + if actreg == ActiveState || (isconst && actreg == MixedState) if length(dptrs) == 0 return Ref{RT}(make_zero(res))::Any else @@ -659,7 +669,7 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc end elseif actreg == MixedState if length(dptrs) == 0 - return Ref{RT}(res)::Any + return Ref{RT}(res) else fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i Base.@_inline_meta @@ -668,6 +678,16 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc end)...)) return fval end + elseif isconst + if length(dptrs) == 0 + return make_zero(res)::Any + else + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + make_zero(res) + end)...)) + return fval + end else if length(dptrs) == 0 return res::Any @@ -858,7 +878,7 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig)))) + push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[2])))) for v in inps[2:end] push!(vals, v) @@ -944,7 +964,7 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig)))) + 
push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[2])))) for v in inps[2:end] push!(vals, v) @@ -1037,7 +1057,7 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig)))) + push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[1])))) for v in inps[2:end] push!(vals, v) @@ -1125,7 +1145,7 @@ function jl_nthfield_rev(B, orig, gutils, tape) sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig)))) + push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[1])))) for v in inps[2:end] push!(vals, v) diff --git a/test/runtests.jl b/test/runtests.jl index 719687ad42..b05df68e18 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2327,7 +2327,35 @@ end adres = Enzyme.autodiff(Reverse, sf_for3, Duplicated(mt3, dmt3), Const(:x), Const(:x), Active(3.1)) @test adres[1][4] ≈ 5050.0 + + mutable struct MyTypeM + x::Float64 + y + end + + @noinline function unstable_mul(x, y) + return (x*y)::Float64 + end + + function gf3(y, v::MyTypeM, fld::Symbol) + x = getfield(v, fld) + unstable_mul(x, y) + end + function gf3(y, v::MyTypeM, fld::Integer) + x = getfield_idx(v, fld) + unstable_mul(x, y) + end + + mx = MyTypeM(3.0, 1) + res = Enzyme.autodiff(Reverse, gf3, Active, Active(2.7), Const(mx), Const(:x)) + @test mx.x ≈ 3.0 + @test res[1][1] ≈ 3.0 + + mx = MyTypeM(3.0, 1) + res = Enzyme.autodiff(Reverse, gf3, Active, Active(2.7), Const(mx), Const(0)) + @test mx.x ≈ 3.0 + @test res[1][1] ≈ 3.0 end From 89f94149fc3a4334cf834d6b70c728edf4fa4fa0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 27 Jun 2024 14:02:20 +0200 Subject: [PATCH 148/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d7fb1d827f..3a66b40d27 100644 --- a/Project.toml +++ b/Project.toml @@ 
-30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.126" +Enzyme_jll = "0.0.127" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From 40fe290d9589a1125101d38e8dd6a558f183fc5b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 14:02:43 +0200 Subject: [PATCH 149/495] CompatHelper: bump compat for LLVM to 8, (keep existing compat) (#1576) Co-authored-by: CompatHelper Julia --- Project.toml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 3a66b40d27..527ab6b14b 100644 --- a/Project.toml +++ b/Project.toml @@ -16,30 +16,30 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[extensions] -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" Enzyme_jll = "0.0.127" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" -LLVM = "6.1, 7" +LLVM = "6.1, 7, 8" ObjectFile = "0.4" Preferences = "1.4" SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.6" +[extensions] +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = 
"276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" From 91083c1ba1592c5ad6805f6eb30e298678670b4c Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 27 Jun 2024 18:18:04 +0200 Subject: [PATCH 150/495] Update api.jl --- src/api.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/api.jl b/src/api.jl index d68d904d5a..fd7b48afcc 100644 --- a/src/api.jl +++ b/src/api.jl @@ -455,7 +455,7 @@ end """ - maxtypeoffset!(val::Bool) + maxtypeoffset!(val::Int) Enzyme runs a type analysis to deduce the corresponding types of all values being differentiated. This is necessary to compute correct derivatives of various values. @@ -472,7 +472,7 @@ function maxtypeoffset!(val) end """ - maxtypedepth!(val::Bool) + maxtypedepth!(val::Int) Enzyme runs a type analysis to deduce the corresponding types of all values being differentiated. This is necessary to compute correct derivatives of various values. From 28686b81db5843040dbfc30be6febce157eb6ebf Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 28 Jun 2024 02:28:44 +0200 Subject: [PATCH 151/495] Fix LLVM feature check (#1581) * Fix LLVM feature check * Update src/compiler/optimize.jl Co-authored-by: Avik Pal * fix? --------- Co-authored-by: Valentin Churavy Co-authored-by: Avik Pal --- src/compiler/optimize.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index dbda1b0880..5e83084419 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1937,8 +1937,12 @@ function optimize!(mod::LLVM.Module, tm) basic_alias_analysis!(pm) cfgsimplification!(pm) dce!(pm) +@static if isdefined(LLVM.Interop, :cpu_features!) + LLVM.Interop.cpu_features!(pm) +else @static if isdefined(GPUCompiler, :cpu_features!) GPUCompiler.cpu_features!(pm) +end end scalar_repl_aggregates_ssa!(pm) # SSA variant? 
mem_cpy_opt!(pm) From 7ffeab30a5e24e782837fa310b38c8ba8c33c9b8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 28 Jun 2024 11:11:07 +0200 Subject: [PATCH 152/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 527ab6b14b..8db022e186 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.127" +Enzyme_jll = "0.0.128" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" ObjectFile = "0.4" From 9115ce5188cc51ef48fcb97a10a0ca69e949318d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 29 Jun 2024 19:17:43 +0100 Subject: [PATCH 153/495] Add hvp helpers (#1583) * Add hvp helpers * fixup * more fixes * fix * more tests * fixup * correct docs * try reduce length of test * cleanup dc * try doctest again * try more * try double escape * Update Project.toml * try again * Update Project.toml --- docs/src/index.md | 59 ++++++++++++++++++++- src/Enzyme.jl | 132 +++++++++++++++++++++++++++++++++++++++++++++- test/abi.jl | 7 +++ 3 files changed, 195 insertions(+), 3 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 0c42482eec..7ea84296ad 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -135,12 +135,12 @@ julia> autodiff(Forward, rosenbrock_inp, BatchDuplicated, BatchDuplicated(x, (dx (400.0, (var"1" = -800.0, var"2" = 400.0)) ``` -## Convenience functions +## Gradient Convenience functions !!! note While the convenience functions discussed below use [`autodiff`](@ref) internally, they are generally more limited in their functionality. Beyond that, these convenience functions may also come with performance penalties; especially if one makes a closure of a multi-argument function instead of calling the appropriate multi-argument [`autodiff`](@ref) function directly. 
-Key convenience functions for common derivative computations are [`gradient`](@ref) (and its inplace variant [`gradient!`](@ref)) and [`jacobian`](@ref). +Key convenience functions for common derivative computations are [`gradient`](@ref) (and its inplace variant [`gradient!`](@ref)). Like [`autodiff`](@ref), the mode (forward or reverse) is determined by the first argument. The functions [`gradient`](@ref) and [`gradient!`](@ref) compute the gradient of function with vector input and scalar return. @@ -174,7 +174,10 @@ julia> # in forward mode, we can also optionally pass a chunk size (-400.0, 200.0) ``` +## Jacobian Convenience functions + The function [`jacobian`](@ref) computes the Jacobian of a function vector input and vector return. +Like [`autodiff`](@ref) and [`gradient`](@ref), the mode (forward or reverse) is determined by the first argument. ```jldoctest rosenbrock julia> foo(x) = [rosenbrock_inp(x), prod(x)]; @@ -202,3 +205,55 @@ julia> # Again, the optinal chunk size argument allows us to use vector forward -400.0 200.0 2.0 1.0 ``` + +## Hessian Vector Product Convenience functions + +Enzyme provides convenience functions for second-order derivative computations, like [`hvp`](@ref) to compute Hessian vector products. Mathematically, this computes $H(x) v$, where $H$ is the hessian operator. + +Unlike [`autodiff`](@ref) and [`gradient`](@ref), a mode is not specified. Here, Enzyme will choose to perform forward over reverse mode (generally the fastest for this type of operation). + +```jldoctest hvp; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +julia> f(x) = sin(x[1] * x[2]); + +julia> hvp(f, [2.0, 3.0], [5.0, 2.7]) +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 +``` + +Enzyme also provides an in-place variant which will store the hessian vector product in a pre-allocated array (this will, however, still allocate another array for storing an intermediate gradient). 
+ +```jldoctest hvp2; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +julia> f(x) = sin(x[1] * x[2]); + +julia> res = Vector{Float64}(undef, 2); + +julia> hvp!(res, f, [2.0, 3.0], [5.0, 2.7]); + +julia> res +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 +``` + +Finally, Enzyme provides a second in-place variant which simultaneously computes both the hessian vector product, and the gradient. This function uses no additional allocation, and is much more efficient than separately computing the hvp and the gradient. + +```jldoctest hvp3; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +julia> f(x) = sin(x[1] * x[2]); + +julia> res = Vector{Float64}(undef, 2); + +julia> grad = Vector{Float64}(undef, 2); + +julia> hvp_and_gradient!(res, grad, f, [2.0, 3.0], [5.0, 2.7]) + +julia> res +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 + +julia> grad +2-element Vector{Float64}: + 2.880510859951098 + 1.920340573300732 +``` \ No newline at end of file diff --git a/src/Enzyme.jl b/src/Enzyme.jl index de694b04f3..7c283d0d1e 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -20,7 +20,7 @@ export batch_size, get_func import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! -export jacobian, gradient, gradient! +export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient! export markType, batch_size, onehot, chunkedonehot using LinearAlgebra @@ -1007,6 +1007,22 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) end end +""" + gradient_deferred(::ReverseMode, f, x) + +Like [`gradient`](@ref), except that it uses deferred mode.
+""" +@inline function gradient_deferred(rm::ReverseMode, f::F, x::X) where {F, X} + if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + dx = Ref(make_zero(x)) + autodiff_deferred(rm, f∘only, Active, Duplicated(Ref(x), dx)) + return only(dx) + else + dx = make_zero(x) + autodiff_deferred(rm, f, Active, Duplicated(x, dx)) + return dx + end +end """ gradient!(::ReverseMode, dx, f, x) @@ -1036,6 +1052,18 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) dx end + +""" + gradient_deferred!(::ReverseMode, dx, f, x) + +Like [`gradient!`](@ref), except that it uses deferred mode. +""" +@inline function gradient_deferred!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} + make_zero!(dx) + autodiff_deferred(Reverse, f, Active, Duplicated(x, dx)) + dx +end + """ gradient(::ForwardMode, f, x; shadow=onehot(x)) @@ -1249,6 +1277,108 @@ end mapreduce(LinearAlgebra.adjoint, vcat, rows) end +""" + hvp(f::F, x::X, v::X) where {F, X} + +Compute the Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v`. + +In other words, compute hessian(f)(x) * v + +See [`hvp!`](@ref) for a version which stores the result in an existing buffer and also [`hvp_and_gradient!`](@ref) for a function to compute both the hvp and the gradient in a single call. + +Example: + +```jldoctest hvp; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +f(x) = sin(x[1] * x[2]) + +hvp(f, [2.0, 3.0], [5.0, 2.7]) + +# output +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 +``` +""" +@inline function hvp(f::F, x::X, v::X) where {F, X} + res = make_zero(x) + hvp!(res, f, x, v) + return res +end + + +""" + hvp!(res::X, f::F, x::X, v::X) where {F, X} + +Compute an in-place Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v`. +The result will be stored into `res`.
The function still allocates and zeros a buffer to store the intermediate gradient, which is +not returned to the user. + +In other words, compute res .= hessian(f)(x) * v + +See [`hvp_and_gradient!`](@ref) for a function to compute both the hvp and the gradient in a single call. + +Example: + +```jldoctest hvpip; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +f(x) = sin(x[1] * x[2]) + +res = Vector{Float64}(undef, 2) +hvp!(res, f, [2.0, 3.0], [5.0, 2.7]) + +res +# output +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 +``` +""" + +@inline function hvp!(res::X, f::F, x::X, v::X) where {F, X} + grad = make_zero(x) + Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) + return nothing +end + + + +""" + hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} + +Compute an in-place Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v` as well as +the gradient, storing the gradient into `grad`. Both the hessian vector product and the gradient can be computed together more efficiently +than computing them separately. + +The result will be stored into `res`. The gradient will be stored into `grad`.
+ +In other words, compute res .= hessian(f)(x) * v and grad .= gradient(Reverse, f)(x) + +Example: + +```jldoctest hvp_and_gradient; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +f(x) = sin(x[1] * x[2]) + +res = Vector{Float64}(undef, 2) +grad = Vector{Float64}(undef, 2) +hvp_and_gradient!(res, grad, f, [2.0, 3.0], [5.0, 2.7]) + +res +grad +# output +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 +2-element Vector{Float64}: + 2.880510859951098 + 1.920340573300732 +``` +""" + +@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} + Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) + return nothing +end + + function _import_frule end # defined in EnzymeChainRulesCoreExt extension """ diff --git a/test/abi.jl b/test/abi.jl index 8d4251bb70..1bc4490a2c 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -411,6 +411,8 @@ end abssum(x) = sum(abs2, x); +mulsin(x) = sin(x[1] * x[2]) + @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) @@ -440,6 +442,11 @@ abssum(x) = sum(abs2, x); @inferred gradient(Reverse, abssum, tx) @inferred gradient(Forward, abssum, tx) + @inferred hvp(mulsin, [2.0, 3.0], [5.0, 2.7]) + + @inferred hvp!(zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7]) + + @inferred hvp_and_gradient!(zeros(2), zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7]) end include("usermixed.jl") \ No newline at end of file From 4bff1fe1a803d84ed107ac41e2f33a5ab9f30225 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 29 Jun 2024 19:45:44 +0100 Subject: [PATCH 154/495] Fixup batched calling conv (#1586) * Fixup batched calling conv * lets add a test why not * Update Project.toml --- deps/build_local.jl | 3 ++- src/api.jl | 1 + src/compiler/optimize.jl | 3 +++ test/abi.jl | 19 ++++++++++++++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/deps/build_local.jl b/deps/build_local.jl index 
5f833ce1a5..f615f25eec 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -21,6 +21,7 @@ while length(args) > 0 global args global branch global source_dir + global BUILD_TYPE if length(args) >= 2 && args[1] == "--branch" branch = args[2] args = (args[3:end]...,) @@ -84,7 +85,7 @@ LLVM_DIR = joinpath(LLVM.artifact_dir, "lib", "cmake", "llvm") LLVM_VER_MAJOR = Base.libllvm_version.major # Build! -@info "Building" source_dir scratch_dir LLVM_DIR +@info "Building" source_dir scratch_dir LLVM_DIR BUILD_TYPE run(`cmake -DLLVM_DIR=$(LLVM_DIR) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) -DENZYME_EXTERNAL_SHARED_LIB=ON -B$(scratch_dir) -S$(source_dir)`) if BCLoad diff --git a/src/api.jl b/src/api.jl index fd7b48afcc..da0958e75c 100644 --- a/src/api.jl +++ b/src/api.jl @@ -771,6 +771,7 @@ EnzymeAttributeKnownFunctions(f) = ccall((:EnzymeAttributeKnownFunctions, libEnz EnzymeAnonymousAliasScopeDomain(str, ctx) = LLVM.Metadata(ccall((:EnzymeAnonymousAliasScopeDomain, libEnzyme), LLVM.API.LLVMMetadataRef, (Cstring,LLVMContextRef), str, ctx)) EnzymeAnonymousAliasScope(dom::LLVM.Metadata, str) = LLVM.Metadata(ccall((:EnzymeAnonymousAliasScope, libEnzyme), LLVM.API.LLVMMetadataRef, (LLVM.API.LLVMMetadataRef,Cstring), dom.ref, str)) EnzymeFixupJuliaCallingConvention(f) = ccall((:EnzymeFixupJuliaCallingConvention, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) +EnzymeFixupBatchedJuliaCallingConvention(f) = ccall((:EnzymeFixupBatchedJuliaCallingConvention, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) e_extract_value!(builder, AggVal, Index, Name::String="") = GC.@preserve Index begin diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 5e83084419..fa2bb9168f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -2205,6 +2205,9 @@ function post_optimze!(mod, tm, machine=true) for f in collect(functions(mod)) API.EnzymeFixupJuliaCallingConvention(f) end + for f in collect(functions(mod)) + API.EnzymeFixupBatchedJuliaCallingConvention(f) + end 
out_error = Ref{Cstring}() if LLVM.API.LLVMVerifyModule(mod, LLVM.API.LLVMReturnStatusAction, out_error) != 0 throw(LLVM.LLVMException("broken gc calling conv fix\n"*string(unsafe_string(out_error[]))*"\n"*string(mod))) diff --git a/test/abi.jl b/test/abi.jl index 1bc4490a2c..d371a7d0a0 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -449,4 +449,21 @@ mulsin(x) = sin(x[1] * x[2]) @inferred hvp_and_gradient!(zeros(2), zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7]) end -include("usermixed.jl") \ No newline at end of file +struct ByRefStruct + x::Vector{Float64} + v::Vector{Float64} +end + +@noinline function byrefg(bref) + return bref.x[1] .+ bref.v[1] +end +function byrefs(x, v) + byrefg(ByRefStruct(x, v)) +end + +@testset "Batched byref struct" begin + + Enzyme.autodiff(Forward, byrefs, BatchDuplicated([1.0], ([1.0], [1.0])), BatchDuplicated([1.0], ([1.0], [1.0]) ) ) +end + +include("usermixed.jl") From a21e60ddbdab383f7caba147514afc95a8bdb150 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 29 Jun 2024 19:46:22 +0100 Subject: [PATCH 155/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8db022e186..1e298e92ba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.20" +version = "0.12.21" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From b3a2bfd68aea9e0a4729e74ebd0b30f8e02ff235 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 30 Jun 2024 11:03:14 +0100 Subject: [PATCH 156/495] Custom rule fix zeroing of inductive var (#1589) * Custom rule fix zeroing of inductive var * Update rrules.jl * Update rrules.jl --- ext/EnzymeChainRulesCoreExt.jl | 6 +++--- src/rules/customrules.jl | 5 +++-- test/rrules.jl | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git 
a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl index 9da9eb97fd..81491f608e 100644 --- a/ext/EnzymeChainRulesCoreExt.jl +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -60,7 +60,7 @@ function Enzyme._import_frule(fn, tys...) dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...) if RetAnnotation <: Const - return nothing + return cres[2]::eltype(RetAnnotation) elseif RetAnnotation <: Duplicated return Duplicated(cres[1], cres[2]) elseif RetAnnotation <: DuplicatedNoNeed @@ -70,12 +70,12 @@ function Enzyme._import_frule(fn, tys...) end else if RetAnnotation <: Const - ntuple(Val(batchsize)) do i + cres = ntuple(Val(batchsize)) do i Base.@_inline_meta dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) end - return nothing + return cres[1][2]::eltype(RetAnnotation) # nothing elseif RetAnnotation <: BatchDuplicated cres1 = begin i = 1 diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 1d9e568eff..a0283f899e 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -779,7 +779,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, funcTy = rev_TT.parameters[isKWCall ? 
4 : 2] if needsTape @assert tape != C_NULL - tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy) + tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy) + (rev_RT == Union{}) trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself + (RT <: Active) innerTy = value_type(parameters(llvmf)[trueidx]) if innerTy != value_type(tape) @@ -823,6 +823,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if API.EnzymeGradientUtilsGetDiffeType(gutils, orig, #=isforeign=#false) == API.DFT_OUT_DIFF val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B)) + API.EnzymeGradientUtilsSetDiffe(gutils, orig, LLVM.null(value_type(val)), B) else llety = convert(LLVMType, eltype(RT)) ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B) @@ -845,7 +846,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if any_jltypes(llty) emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)), al) + insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)) + (rev_RT == Union{}), al) end end diff --git a/test/rrules.jl b/test/rrules.jl index ee3b9af138..b6681e6739 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -345,5 +345,38 @@ end @test cl.v[1] ≈ 0.0 end + +function times2(wt_y) + return wt_y*2 +end +function EnzymeRules.augmented_primal(config, ::Const{typeof(times2)}, FA, x) + return EnzymeRules.AugmentedReturn(2*x.val, nothing, nothing) +end +function EnzymeRules.reverse(config, ::Const{typeof(times2)}, FA, tape, arg) + return (46.7*FA.val,) +end + + +function times2_ar(x) + n = length(x) + res = Vector{Float64}(undef, n) + i = 1 + while true + @inbounds res[i] = @inbounds times2(@inbounds x[i]) + if i == n + break + end + i+=1 + end + return res[3]::Float64 +end + +@testset "Zero diffe result" begin + vals = [2.7, 5.6, 7.8, 12.2] + dvals = 
zero(vals) + Enzyme.autodiff(Reverse, times2_ar, Duplicated(vals, dvals)) + @test dvals ≈ [0., 0., 46.7, 0.] +end + include("mixedrrule.jl") end # ReverseRules From 54dc0720ebc15a80791d871afb59361f1ca2daae Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 30 Jun 2024 16:37:41 +0100 Subject: [PATCH 157/495] Fix make_zero of arbitrary struct (#1591) * Fix make_zero of arbitrary struct * fix * Update runtests.jl --- src/compiler.jl | 2 +- test/runtests.jl | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 94ff0bca18..f9106e7602 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1573,7 +1573,7 @@ end if guaranteed_const_nongen(T, nothing) return end - if in(seen, prev) + if in(prev, seen) return end @assert !Base.isabstracttype(T) diff --git a/test/runtests.jl b/test/runtests.jl index b05df68e18..d646206e95 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -207,6 +207,19 @@ end z4 = sin Enzyme.make_zero!(z4) + + struct Dense + n_inp::Int + b::Vector{Float64} + end + + function Dense(n) + Dense(n, rand(n)) + end + + nn = Dense(4) + Enzyme.make_zero!(nn) + @test nn.b ≈ [0.0, 0.0, 0.0, 0.0] end @testset "Reflection" begin From 94ddc47cf11713404510cb328c103bd00760f642 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 30 Jun 2024 16:38:01 +0100 Subject: [PATCH 158/495] Cherry pick of forward cholesky updates (#1592) * Cherry pick of forward cholesky updates * Add necessary tests --- src/internal_rules.jl | 90 +++++++++++++++++++++++++----------------- test/internal_rules.jl | 15 +++++++ 2 files changed, 68 insertions(+), 37 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 58a312063f..f155f69a29 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -744,6 +744,32 @@ function EnzymeRules.reverse( return (nothing, nothing) end +function _cholesky_forward(C::Cholesky, Ȧ) + # Computes the cholesky forward mode update rule + # C.f. eq. 
8 in https://arxiv.org/pdf/1602.07527.pdf + if C.uplo === 'U' + U = C.U + U̇ = Ȧ / U + ldiv!(U', U̇) + idx = diagind(U̇) + U̇[idx] ./= 2 + triu!(U̇) + rmul!(U̇, U) + U̇ .+= UpperTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle + return Cholesky(U̇, 'U', C.info) + else + L = C.L + L̇ = L \ Ȧ + rdiv!(L̇, L') + idx = diagind(L̇) + L̇[idx] ./= 2 + tril!(L̇) + lmul!(L, L̇) + L̇ .+= LowerTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle + return Cholesky(L̇, 'L', C.info) + end +end + function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) fact = cholesky(A.val; kwargs...) if RT <: Const @@ -756,7 +782,7 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) dA = if isa(A, Const) ntuple(Val(N)) do i Base.@_inline_meta - zeros(A.val) + zero(A.val) end else if N == 1 @@ -768,9 +794,7 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) dfact = ntuple(Val(N)) do i Base.@_inline_meta - Cholesky( - Matrix(fact.L * LowerTriangular(invL * dA[i] * invL' * 0.5 * I)), 'L', 0 - ) + return _cholesky_forward(fact, dA[i]) end if (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed) @@ -788,49 +812,41 @@ end # -> # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] -function EnzymeRules.forward( - func::Const{typeof(ldiv!)}, - RT::Type, - fact::Annotation{<:Cholesky}, - B; - kwargs... -) - if isa(B, Const) - @assert (RT <: Const) +function EnzymeRules.forward(func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, + fact::Annotation{<:Cholesky}, + B::Annotation{<:AbstractVecOrMat}; + kwargs...) + if B isa Const return func.val(fact.val, B.val; kwargs...) else N = width(B) + retval = B.val - @assert !isa(B, Const) + L = fact.val.L + U = fact.val.U - retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) - func.val(fact.val, B.val; kwargs...) 
- else - nothing + ldiv!(L, B.val) + ntuple(Val(N)) do b + Base.@_inline_meta + dB = N == 1 ? B.dval : B.dval[b] + if !(fact isa Const) + dL = N == 1 ? fact.dval.L : fact.dval[b].L + mul!(dB, dL, B.val, -1, 1) + end + ldiv!(L, dB) end + ldiv!(U, B.val) dretvals = ntuple(Val(N)) do b Base.@_inline_meta - - dB = if N == 1 - B.dval - else - B.dval[b] + dB = N == 1 ? B.dval : B.dval[b] + if !(fact isa Const) + dU = N == 1 ? fact.dval.U : fact.dval[b].U + mul!(dB, dU, B.val, -1, 1) end - - if !isa(fact, Const) - - dfact = if N == 1 - fact.dval - else - fact.dval[b] - end - - tmp = dfact.U * retval - mul!(dB, dfact.L, tmp, -1, 1) - end - - func.val(fact.val, dB; kwargs...) + ldiv!(U, dB) + return dB end if RT <: Const diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 965d8d4b55..fe16a52588 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -413,6 +413,21 @@ end end end +function chol_upper(x) + x = reshape(x, 4, 4) + x = parent(cholesky(Hermitian(x)).U) + x = convert(typeof(x), UpperTriangular(x)) + return x[1,2] +end + +@testset "Cholesky upper triangular" begin + x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0] + + @test collect(Enzyme.gradient(Forward, chol_upper, x)) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + + @test_broken Enzyme.gradient(Reverse, chol_upper, x) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +end + @testset "Linear solve for triangular matrices" begin @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) From 800053b348210e092c3a81de95a8b27256e5e23a Mon Sep 17 
00:00:00 2001 From: William Moses Date: Sun, 30 Jun 2024 18:41:04 +0100 Subject: [PATCH 159/495] More fwdblas (#1590) * More fwdblas * Update compiler.jl --- Project.toml | 2 +- src/compiler.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 1e298e92ba..3b64ba9426 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.128" +Enzyme_jll = "0.0.129" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" ObjectFile = "0.4" diff --git a/src/compiler.jl b/src/compiler.jl index f9106e7602..7be1f68725 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5174,8 +5174,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; disableFallback = String[] - ForwardModeDerivatives = ("dot","gemm","gemv","axpy","copy","scal") - ReverseModeDerivatives = ("dot","gemm","gemv","axpy","copy","scal", "trmv", "syrk", "trmm", "trsm") + ForwardModeDerivatives = ("nrm2", "dot","gemm","gemv","axpy","copy","scal", "syrk") + ReverseModeDerivatives = (#="nrm2",=# "dot","gemm","gemv","axpy","copy","scal", "trmv", "syrk", "trmm", "trsm") # Tablegen BLAS does not support forward mode yet if !(mode == API.DEM_ForwardMode && Enzyme.API.runtimeActivity()) for ty in ("s", "d") From f26f6f5639d70c5eee74cd5b0d0bf583392f6c70 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 2 Jul 2024 16:45:26 +0100 Subject: [PATCH 160/495] Fix noreturn removal (#1602) --- src/compiler/optimize.jl | 5 ++++- src/compiler/utils.jl | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index fa2bb9168f..946e7ca90a 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1158,9 +1158,12 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) if isempty(blocks(cur)) return 
false end + + err_is_readonly = !is_noreturn(cur) + for bb in blocks(cur) for inst in instructions(bb) - if !mayWriteToMemory(inst; err_is_readonly=true) + if !mayWriteToMemory(inst; err_is_readonly) continue end if isa(inst, LLVM.CallInst) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index b5bdb3afa2..9595927aff 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -100,6 +100,15 @@ for n in (:is_readonly, :is_readnone, :is_writeonly) end end +function is_noreturn(f::LLVM.Function) + for attr in collect(function_attributes(f)) + if kind(attr) == kind(EnumAttribute("noreturn")) + return true + end + end + return false +end + function is_readonly(f::LLVM.Function) for attr in collect(function_attributes(f)) if kind(attr) == kind(EnumAttribute("readonly")) From 0a7eca7417756b7958a3cb363d3e0a512d7a2290 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 4 Jul 2024 14:29:02 -0400 Subject: [PATCH 161/495] Mark hasproperty as inactivenoinl (#1607) --- src/internal_rules.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index f155f69a29..6abb5b469a 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -106,6 +106,10 @@ function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, :: return nothing end +function EnzymeRules.inactive_noinl(::typeof(Base.hasproperty), args...) + return nothing +end + if VERSION >= v"1.9" Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= nothing end From ba38959ab899a72b47fad9b3ee50994e2bf9f17e Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 4 Jul 2024 16:57:19 -0400 Subject: [PATCH 162/495] Inactive isdefined (#1609) --- src/absint.jl | 4 ++++ src/compiler.jl | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/absint.jl b/src/absint.jl index 90e7a35427..f8b33d0197 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -185,6 +185,10 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ nm = LLVM.name(fn) index += 1 end + + if nm == "jl_f_isdefined" || nm == "ijl_f_isdefined" + return Bool + end if nm == "jl_new_structv" || nm == "ijl_new_structv" @assert index == 2 diff --git a/src/compiler.jl b/src/compiler.jl index 7be1f68725..dde54978ec 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_f_isdefined", "jl_f_isdefined", "ijl_field_index", "jl_field_index", "ijl_specializations_get_linfo", "jl_specializations_get_linfo", "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", @@ -184,6 +185,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_f_isdefined", "jl_f_isdefined", "ijl_field_index", "jl_field_index", "ijl_specializations_get_linfo", "jl_specializations_get_linfo", "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", From 14f87b4afd879a93543cd0dabacd272875ec826a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 6 Jul 2024 00:47:24 -0400 Subject: [PATCH 163/495] Fix make_zero! 
issues (#1612) --- src/compiler.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dde54978ec..a043183525 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1439,8 +1439,9 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} if guaranteed_const_nongen(T, nothing) return prev end - @assert !ismutable(T) + @assert !ismutable(prev) + RT = Core.Typeof(prev) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -6148,7 +6149,7 @@ end @inline mutable_register(::Type{T}) where T <: NamedTuple = false @inline mutable_register(::Type{Core.Box}) = true @inline mutable_register(::Type{T}) where T <: Array = true -@inline mutable_register(::Type{T}) where T = ismutable(T) +@inline mutable_register(::Type{T}) where T = ismutabletype(T) # Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) @inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F} From dd5338f6e62b41e100f2fc3b21de386d4c1f8929 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 6 Jul 2024 09:20:06 -0400 Subject: [PATCH 164/495] Fix isdefined abstypeof (#1613) --- src/absint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index f8b33d0197..5d9c595059 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -187,7 +187,7 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end if nm == "jl_f_isdefined" || nm == "ijl_f_isdefined" - return Bool + return (true, Bool) end if nm == "jl_new_structv" || nm == "ijl_new_structv" From 4da44dd8f7bd1234d2dee23ec8a14cc440e2225f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 6 Jul 2024 10:20:47 -0400 Subject: [PATCH 165/495] Additional tests for jll bump (#1598) * Additional tests for jll bump * Update compiler.jl * Update Project.toml * Update Project.toml * Update compiler.jl * Update runtests.jl * Update runtests.jl * 
Update runtests.jl * Update runtests.jl * try name fix * fix syntax * Revert "fix syntax" This reverts commit e782914380609681d7076e00d5b6245ff2a2fe89. * Revert "try name fix" This reverts commit 0fa09e49f99c522b81e68309e94fb8121afb6fc1. * permit demotion into addr10 of vector insert * tmp to debug * fix * set operand fix * restore weird decayaddr update * bref sret * with dce * fix * move * up * fix * re-enable nrm2 * Update Project.toml * Update runtests.jl --- Project.toml | 24 ++-- deps/build_local.jl | 1 + src/compiler.jl | 19 ++- src/compiler/optimize.jl | 16 ++- src/internal_rules.jl | 272 +++++++++++---------------------------- src/rules/llvmrules.jl | 6 +- test/internal_rules.jl | 119 +++++++++++++++-- test/runtests.jl | 47 +++++++ 8 files changed, 280 insertions(+), 224 deletions(-) diff --git a/Project.toml b/Project.toml index 3b64ba9426..fad09d3ef0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.21" +version = "0.12.22" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -16,11 +16,21 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.129" +Enzyme_jll = "0.0.131" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" ObjectFile = "0.4" @@ -29,17 +39,7 @@ SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.6" -[extensions] -EnzymeChainRulesCoreExt = "ChainRulesCore" 
-EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/deps/build_local.jl b/deps/build_local.jl index f615f25eec..32089e8553 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -22,6 +22,7 @@ while length(args) > 0 global branch global source_dir global BUILD_TYPE + global BCLoad if length(args) >= 2 && args[1] == "--branch" branch = args[2] args = (args[3:end]...,) diff --git a/src/compiler.jl b/src/compiler.jl index a043183525..4090bc304c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3736,7 +3736,18 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr @assert "Unhandled derivative mode", mode end API.EnzymeLogicErasePreprocessedFunctions(logic) + adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) + augmented_primalfname = augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) + for f in collect(functions(mod)) + API.EnzymeFixupBatchedJuliaCallingConvention(f) + end + ModulePassManager() do pm + dce!(pm) + run!(pm, mod) + end fix_decayaddr!(mod) + adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] + augmented_primalf = augmented_primalf == nothing ? 
nothing : functions(mod)[augmented_primalfname] return adjointf, augmented_primalf, TapeType end @@ -5177,11 +5188,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; disableFallback = String[] - ForwardModeDerivatives = ("nrm2", "dot","gemm","gemv","axpy","copy","scal", "syrk") - ReverseModeDerivatives = (#="nrm2",=# "dot","gemm","gemv","axpy","copy","scal", "trmv", "syrk", "trmm", "trsm") + ForwardModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "syrk", "potrf") + ReverseModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "trmv", "syrk", "trmm", "trsm", "potrf") + ForwardModeTypes = ("s", "d", "c", "z") + ReverseModeTypes = ("s", "d") # Tablegen BLAS does not support forward mode yet if !(mode == API.DEM_ForwardMode && Enzyme.API.runtimeActivity()) - for ty in ("s", "d") + for ty in (mode == API.DEM_ForwardMode ? ForwardModeTypes : ReverseModeTypes) for func in (mode == API.DEM_ForwardMode ? ForwardModeDerivatives : ReverseModeDerivatives) for prefix in ("", "cblas_") for ending in ("", "_", "64_", "_64_") diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 946e7ca90a..578059803a 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -786,7 +786,7 @@ end function fix_decayaddr!(mod::LLVM.Module) for f in functions(mod) - invalid = LLVM.AddrSpaceCastInst[] + invalid = LLVM.Instruction[] for bb in blocks(f), inst in instructions(bb) if !isa(inst, LLVM.AddrSpaceCastInst) continue @@ -815,6 +815,18 @@ function fix_decayaddr!(mod::LLVM.Module) continue end end + # if isa(st, LLVM.InsertValueInst) + # if operands(st)[1] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 1-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # if operands(st)[2] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 2-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # end if !isa(st, LLVM.CallInst) bt = GPUCompiler.backtrace(st) msg = sprint() 
do io::IO @@ -828,7 +840,7 @@ function fix_decayaddr!(mod::LLVM.Module) println(io) end end - throw(AssertionError(msg)) + throw(AssertionError(msg)) end fop = operands(st)[end] diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 6abb5b469a..fb3a60954f 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -748,69 +748,6 @@ function EnzymeRules.reverse( return (nothing, nothing) end -function _cholesky_forward(C::Cholesky, Ȧ) - # Computes the cholesky forward mode update rule - # C.f. eq. 8 in https://arxiv.org/pdf/1602.07527.pdf - if C.uplo === 'U' - U = C.U - U̇ = Ȧ / U - ldiv!(U', U̇) - idx = diagind(U̇) - U̇[idx] ./= 2 - triu!(U̇) - rmul!(U̇, U) - U̇ .+= UpperTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle - return Cholesky(U̇, 'U', C.info) - else - L = C.L - L̇ = L \ Ȧ - rdiv!(L̇, L') - idx = diagind(L̇) - L̇[idx] ./= 2 - tril!(L̇) - lmul!(L, L̇) - L̇ .+= LowerTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle - return Cholesky(L̇, 'L', C.info) - end -end - -function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) - fact = cholesky(A.val; kwargs...) - if RT <: Const - return fact - else - N = width(RT) - - invL = inv(fact.L) - - dA = if isa(A, Const) - ntuple(Val(N)) do i - Base.@_inline_meta - zero(A.val) - end - else - if N == 1 - (A.dval,) - else - A.dval - end - end - - dfact = ntuple(Val(N)) do i - Base.@_inline_meta - return _cholesky_forward(fact, dA[i]) - end - - if (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed) - return dfact - elseif RT <: Duplicated - return Duplicated(fact, dfact[1]) - else - return BatchDuplicated(fact, dfact) - end - end -end - # y = inv(A) B # dY = inv(A) [ dB - dA y ] # -> @@ -867,63 +804,6 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end end -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(cholesky)}, - RT::Type, - A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) 
- fact = if EnzymeRules.needs_primal(config) - cholesky(A.val; kwargs...) - else - nothing - end - - # dfact would be a dense matrix, prepare buffer - dfact = if RT <: Const - nothing - else - if EnzymeRules.width(config) == 1 - Enzyme.make_zero(fact) - else - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - Enzyme.make_zero(fact) - end - end - end - cache = if isa(A, Const) - nothing - else - dfact - end - - return EnzymeRules.AugmentedReturn(fact, dfact, cache) -end - -function EnzymeRules.reverse( - config, - ::Const{typeof(cholesky)}, - RT::Type, - dfact, - A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) - - if !(RT <: Const) && !isa(A, Const) - dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval - dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact - - for (dA, dfact) in zip(dAs, dfacts) - _dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA - if _dA !== dfact.factors - _dA .+= dfact.factors - dfact.factors .= 0 - end - end - end - return (nothing,) -end - # y=inv(A) B # dA −= z y^T @@ -933,79 +813,79 @@ end # B(out)=inv(A) B(in) # dA −= z B(out)^T # dB = z, where z = inv(A^T) dB -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, - - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... -) - func.val(A.val, B.val; kwargs...) 
- - cache_Bout = if !isa(A, Const) && !isa(B, Const) - if EnzymeRules.overwritten(config)[3] - copy(B.val) - else - B.val - end - else - nothing - end - - cache_A = if !isa(B, Const) - if EnzymeRules.overwritten(config)[2] - copy(A.val) - else - A.val - end - else - nothing - end - - primal = if EnzymeRules.needs_primal(config) - B.val - else - nothing - end - - shadow = if EnzymeRules.needs_shadow(config) - B.dval - else - nothing - end - - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_Bout)) -end - -function EnzymeRules.reverse( - config, - func::Const{typeof(ldiv!)}, - dret, - cache, - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... -) - if !isa(B, Const) - - (cache_A, cache_Bout) = cache - - for b in 1:EnzymeRules.width(config) - - dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] - - # dB = z, where z = inv(A^T) dB - # dA −= z B(out)^T - - func.val(cache_A, dB; kwargs...) - if !isa(A, Const) - dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - mul!(dA.factors, dB, transpose(cache_Bout), -1, 1) - end - end - end - - return (nothing, nothing) -end +# function EnzymeRules.augmented_primal( +# config, +# func::Const{typeof(ldiv!)}, +# RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, +# +# A::Annotation{<:Cholesky}, +# B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; +# kwargs... +# ) +# func.val(A.val, B.val; kwargs...) 
+# +# cache_Bout = if !isa(A, Const) && !isa(B, Const) +# if EnzymeRules.overwritten(config)[3] +# copy(B.val) +# else +# B.val +# end +# else +# nothing +# end +# +# cache_A = if !isa(B, Const) +# if EnzymeRules.overwritten(config)[2] +# copy(A.val) +# else +# A.val +# end +# else +# nothing +# end +# +# primal = if EnzymeRules.needs_primal(config) +# B.val +# else +# nothing +# end +# +# shadow = if EnzymeRules.needs_shadow(config) +# B.dval +# else +# nothing +# end +# +# return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_Bout)) +# end +# +# function EnzymeRules.reverse( +# config, +# func::Const{typeof(ldiv!)}, +# dret, +# cache, +# A::Annotation{<:Cholesky}, +# B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; +# kwargs... +# ) +# if !isa(B, Const) +# +# (cache_A, cache_Bout) = cache +# +# for b in 1:EnzymeRules.width(config) +# +# dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] +# +# # dB = z, where z = inv(A^T) dB +# # dA −= z B(out)^T +# +# func.val(cache_A, dB; kwargs...) +# if !isa(A, Const) +# dA = EnzymeRules.width(config) == 1 ? 
A.dval : A.dval[b] +# mul!(dA.factors, dB, transpose(cache_Bout), -1, 1) +# end +# end +# end +# +# return (nothing, nothing) +# end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index d0146a64e2..91d6b44820 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -247,12 +247,12 @@ function arraycopy_fwd(B, orig, gutils, normalR, shadowR) ev = extract_value!(B, shadowin, idx-1) callv = call_samefunc_with_inverted_bundles!(B, gutils, orig, [ev], [API.VT_Shadow], #=lookup=#false) if is_constant_value(gutils, origops[1]) - elSize = get_array_elsz(B, shadowin) + elSize = get_array_elsz(B, ev) elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) - len = get_array_len(B, shadowin) + len = get_array_len(B, ev) length = LLVM.mul!(B, len, elSize) GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type" - LLVM.memset!(B, get_array_data(callv), LLVM.ConstantInt(i8, 0, false), length, algn) + LLVM.memset!(B, get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn) end if API.runtimeActivity() prev = new_from_original(gutils, orig) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index fe16a52588..a5479629a0 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -135,6 +135,108 @@ end @test dA ≈ (-z * transpose(y)) end +function chol_lower0(x) + c = copy(x) + C, info = LinearAlgebra.LAPACK.potrf!('L', c) + return c[2,1] +end + +function chol_upper0(x) + c = copy(x) + C, info = LinearAlgebra.LAPACK.potrf!('U', c) + return c[1,2] +end + +@testset "Cholesky PotRF" begin + x = reshape([1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0], 4, 4) + dL = zero(x) + dL[2, 1] = 1.0 + + @test Enzyme.gradient(Reverse, chol_lower0, x) ≈ [0.05270807565639164 0.0 
0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + + @test reshape(collect(Enzyme.gradient(Forward, chol_lower0, x)), 4, 4) ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + + @test FiniteDifferences.grad(central_fdm(5, 1), chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + + @test reshape(collect(Enzyme.gradient(Forward, chol_upper0, x)), 4, 4) ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Reverse, chol_upper0, x) ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test FiniteDifferences.grad(central_fdm(5, 1), chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] +end + + +function tchol_lower(x, row, col) + c = copy(x) + C, info = LinearAlgebra.LAPACK.potrf!('L', c) + return c[row, col] +end +function tchol_upper(x, row, col) + c = copy(x) + C, info = LinearAlgebra.LAPACK.potrf!('U', c) + return c[row, col] +end + +@testset "Cholesky PotRF 3x3" begin + + x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] + for i in 1:size(x, 1) + for j in 1:size(x, 2) + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x) + forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)), size(x)) + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_lower(x, i, j), x)[1] + @test reverse_grad ≈ finite_diff + @test forward_grad ≈ finite_diff + + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x) + forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)), size(x)) + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_upper(x, i, j), 
x)[1] + @test reverse_grad ≈ finite_diff + @test forward_grad ≈ finite_diff + end + end +end + +function tcholsolv_lower(A, B, i) + c = copy(B) + C, info = LinearAlgebra.LAPACK.potrs!('L', A, c) + return c[i] +end +function tcholsolv_upper(A, B, i) + c = copy(B) + C, info = LinearAlgebra.LAPACK.potrs!('U', A, c) + return c[i] +end + +@testset "Cholesky PotRS 3x5" begin + + x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] + for i in 1:15 + B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3] + reverse_grad = Enzyme.gradient(Reverse, B -> tcholsolv_lower(x, B, i), B) + # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)), size(B)) + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1] + @test reverse_grad ≈ finite_diff + # @test forward_grad ≈ finite_diff + + reverse_grad = Enzyme.gradient(Reverse, B -> tcholsolv_upper(x, B, i), B) + # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B)), size(B)) + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1] + @test reverse_grad ≈ finite_diff + # @test forward_grad ≈ finite_diff + + reverse_grad = Enzyme.gradient(Reverse, x -> tcholsolv_lower(x, B, i), x) + #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)), size(x)) + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1] + @test reverse_grad ≈ finite_diff + #@test forward_grad ≈ finite_diff + # + reverse_grad = Enzyme.gradient(Reverse, x -> tcholsolv_upper(x, B, i), x) + #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)), size(x)) + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1] + @test reverse_grad ≈ finite_diff + #@test forward_grad ≈ finite_diff + 
end +end + @static if VERSION > v"1.8" @testset "Cholesky" begin function symmetric_definite(n :: Int=10) @@ -396,20 +498,21 @@ end @test isapprox(fwdJ, revJ) function h(A, b) - C = cholesky(A) + A = copy(A) + LinearAlgebra.LAPACK.potrf!('U', A) b2 = copy(b) - ldiv!(C, b2) + LinearAlgebra.LAPACK.potrs!('U', A, b2) @inbounds b2[1] end A = [1.3 0.5; 0.5 1.5] b = [1., 2.] - V = [1.0 0.0; 0.0 0.0] dA = zero(A) Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) + # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A) + dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A->h(A, b), A)[1] - dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) - @test isapprox(dA, dA_sym) + @test isapprox(dA, dA_fd) end end @@ -420,14 +523,14 @@ function chol_upper(x) return x[1,2] end -@testset "Cholesky upper triangular" begin +@testset "Cholesky upper triangular v1" begin x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0] @test collect(Enzyme.gradient(Forward, chol_upper, x)) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - @test_broken Enzyme.gradient(Reverse, chol_upper, x) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + @test Enzyme.gradient(Reverse, chol_upper, x) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] end - + @testset "Linear solve for triangular matrices" begin @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) diff --git a/test/runtests.jl b/test/runtests.jl index d646206e95..dbc9d7c836 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ 
-287,6 +287,34 @@ sqrtsumsq2(x) = (sum(abs2, x)*sum(abs2,x)) Enzyme.autodiff(Reverse, sqrtsumsq2, Duplicated(x,dx)) end +@noinline function prt_sret(A) + A[1] *= 2 + return (A, A[2]) +end + +@noinline function sretf(A2, x, c) + x[3] = c * A2[3] +end + +@noinline function batchdecaysret0(x, A, b) + A2, c = prt_sret(A) + sretf(A2, x, c) + return nothing +end + +function batchdecaysret(x, A, b) + batchdecaysret0(x, A, b) + A[2] = 0 + return nothing +end + +@testset "Batch Reverse sret fix" begin + Enzyme.autodiff(Reverse, batchdecaysret, + BatchDuplicated(ones(3), (ones(3), ones(3))), + BatchDuplicated(ones(3), (ones(3), ones(3))), + BatchDuplicated(ones(3), (ones(3), ones(3)))) +end + # @testset "Split Tape" begin # f(x) = x[1] * x[1] @@ -3185,6 +3213,25 @@ end end end +struct GDoubleField{T} + this_field_does_nothing::T + b::T +end + +GDoubleField() = GDoubleField{Float64}(0.0, 1.0) +function fexpandempty(vec) + x = vec[1] + empty = [] + d = GDoubleField(empty...) + return x ≤ d.b ? x * d.b : zero(x) +end + +@testset "Constant Complex return" begin + vec = [0.5] + @test Enzyme.gradient(Enzyme.Reverse, fexpandempty, vec)[1] ≈ 1.0 + @test Enzyme.gradient(Enzyme.Forward, fexpandempty, vec)[1] ≈ 1.0 +end + const CUmemoryPool2 = Ptr{Float64} struct CUmemPoolProps2 From c799584d85afeff75eb304ea57583d5fd97de98b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 7 Jul 2024 20:58:53 -0400 Subject: [PATCH 166/495] Writeonly capture fix (#1616) * Writeonly capture fix * Update runtests.jl * Update runtests.jl * Update runtests.jl * Update runtests.jl --- src/compiler/optimize.jl | 24 +++++++++++++++--------- test/runtests.jl | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 578059803a..7aebf9be0f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1543,8 +1543,11 @@ function detect_writeonly!(mod::LLVM.Module) end for (i, a) in enumerate(parameters(f)) if 
isa(value_type(a), LLVM.PointerType) - todo = LLVM.Value[a] - seen = Set{LLVM.Value}() + todo = Tuple{LLVM.Value, LLVM.Instruction}[] + for u in LLVM.uses(a) + push!(todo, (a, LLVM.user(u))) + end + seen = Set{Tuple{LLVM.Value, LLVM.Instruction}}() mayread = false maywrite = false while length(todo) > 0 @@ -1553,20 +1556,23 @@ function detect_writeonly!(mod::LLVM.Module) continue end push!(seen, cur) + curv, curi = cur - if isa(cur, LLVM.StoreInst) - maywrite = true - continue + if isa(curi, LLVM.StoreInst) + if operands(curi)[1] != curv + maywrite = true + continue + end end - if isa(cur, LLVM.LoadInst) + if isa(curi, LLVM.LoadInst) mayread = true continue end - if isa(cur, LLVM.Argument) || isa(cur, LLVM.GetElementPtrInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) - for u in LLVM.uses(cur) - push!(todo, LLVM.user(u)) + if isa(curi, LLVM.GetElementPtrInst) || isa(curi, LLVM.BitCastInst) || isa(curi, LLVM.AddrSpaceCastInst) + for u in LLVM.uses(curi) + push!(todo, (curi, LLVM.user(u))) end continue end diff --git a/test/runtests.jl b/test/runtests.jl index dbc9d7c836..9bebab68b0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -718,6 +718,24 @@ function euroad(f::T) where T return g end +@noinline function womylogpdf(X::AbstractArray{<:Real}) + map(womylogpdf, X) +end + +function womylogpdf(x::Real) + (x - 2) +end + + +function wologpdf_test(x) + return womylogpdf(x) +end + +@testset "Ensure writeonly deduction combines with capture" begin + res = Enzyme.autodiff(Enzyme.Forward, wologpdf_test, Duplicated([0.5], [0.7])) + @test res[1] ≈ [0.7] +end + euroad′(x) = first(autodiff(Reverse, euroad, Active, Active(x)))[1] @test euroad(0.5) ≈ -log(0.5) # -log(1-x) From c83fcf8960abfba435752dba318a883da4d752c6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 8 Jul 2024 19:29:20 -0400 Subject: [PATCH 167/495] Explore removal of triangular rule (#1614) * Explore removal of triangular rule * Test triangular solve * restore for complex * 
Update Project.toml * Update runtests.jl * Update Project.toml --- Project.toml | 2 +- src/internal_rules.jl | 94 ++---------------------------------------- test/internal_rules.jl | 32 ++++++++++++++ test/runtests.jl | 13 +++--- 4 files changed, 43 insertions(+), 98 deletions(-) diff --git a/Project.toml b/Project.toml index fad09d3ef0..b5f6f50d24 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.131" +Enzyme_jll = "0.0.133" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" ObjectFile = "0.4" diff --git a/src/internal_rules.jl b/src/internal_rules.jl index fb3a60954f..fd18bb0261 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -458,10 +458,10 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, end const EnzymeTriangulars = Union{ - UpperTriangular, - LowerTriangular, - UnitUpperTriangular, - UnitLowerTriangular + UpperTriangular{<:Complex}, + LowerTriangular{<:Complex}, + UnitUpperTriangular{<:Complex}, + UnitLowerTriangular{<:Complex} } function EnzymeRules.augmented_primal( @@ -803,89 +803,3 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end end end - - -# y=inv(A) B -# dA −= z y^T -# dB += z, where z = inv(A^T) dy -# -> -# -# B(out)=inv(A) B(in) -# dA −= z B(out)^T -# dB = z, where z = inv(A^T) dB -# function EnzymeRules.augmented_primal( -# config, -# func::Const{typeof(ldiv!)}, -# RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, -# -# A::Annotation{<:Cholesky}, -# B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; -# kwargs... -# ) -# func.val(A.val, B.val; kwargs...) 
-# -# cache_Bout = if !isa(A, Const) && !isa(B, Const) -# if EnzymeRules.overwritten(config)[3] -# copy(B.val) -# else -# B.val -# end -# else -# nothing -# end -# -# cache_A = if !isa(B, Const) -# if EnzymeRules.overwritten(config)[2] -# copy(A.val) -# else -# A.val -# end -# else -# nothing -# end -# -# primal = if EnzymeRules.needs_primal(config) -# B.val -# else -# nothing -# end -# -# shadow = if EnzymeRules.needs_shadow(config) -# B.dval -# else -# nothing -# end -# -# return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_Bout)) -# end -# -# function EnzymeRules.reverse( -# config, -# func::Const{typeof(ldiv!)}, -# dret, -# cache, -# A::Annotation{<:Cholesky}, -# B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; -# kwargs... -# ) -# if !isa(B, Const) -# -# (cache_A, cache_Bout) = cache -# -# for b in 1:EnzymeRules.width(config) -# -# dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] -# -# # dB = z, where z = inv(A^T) dB -# # dA −= z B(out)^T -# -# func.val(cache_A, dB; kwargs...) -# if !isa(A, Const) -# dA = EnzymeRules.width(config) == 1 ? 
A.dval : A.dval[b] -# mul!(dA.factors, dB, transpose(cache_Bout), -1, 1) -# end -# end -# end -# -# return (nothing, nothing) -# end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index a5479629a0..7cc5c07321 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -135,6 +135,38 @@ end @test dA ≈ (-z * transpose(y)) end +function tr_solv(A, B, uplo, trans, diag, idx) + B = copy(B) + LAPACK.trtrs!(uplo, trans, diag, A, B) + return @inbounds B[idx] +end + + +@testset "Reverse triangular solve" begin + A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542] + B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784] + for idx in 1:15 + for uplo in ('L', 'U') + for diag in ('N', 'U') + for trans in ('N', 'T') + dA = zero(A) + dB = zero(B) + Enzyme.autodiff(Reverse, tr_solv, Duplicated(A, dA), Duplicated(B, dB), Const(uplo),Const(trans), Const(diag), Const(idx)) + fA = FiniteDifferences.grad(central_fdm(5, 1), A->tr_solv(A, B, uplo, trans, diag, idx), A)[1] + fB = FiniteDifferences.grad(central_fdm(5, 1), B->tr_solv(A, B, uplo, trans, diag, idx), B)[1] + + if max(abs.(dA)...) >= 1e-10 || max(abs.(fA)...) >= 1e-10 + @test dA ≈ fA + end + if max(abs.(dB)...) >= 1e-10 || max(abs.(fB)...) 
>= 1e-10 + @test dB ≈ fB + end + end + end + end + end +end + function chol_lower0(x) c = copy(x) C, info = LinearAlgebra.LAPACK.potrf!('L', c) diff --git a/test/runtests.jl b/test/runtests.jl index 9bebab68b0..610284bfb8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -718,6 +718,12 @@ function euroad(f::T) where T return g end +euroad′(x) = first(autodiff(Reverse, euroad, Active, Active(x)))[1] + +@test euroad(0.5) ≈ -log(0.5) # -log(1-x) +@test euroad′(0.5) ≈ 2.0 # d/dx -log(1-x) = 1/(1-x) +test_scalar(euroad, 0.5) +end @noinline function womylogpdf(X::AbstractArray{<:Real}) map(womylogpdf, X) end @@ -736,13 +742,6 @@ end @test res[1] ≈ [0.7] end -euroad′(x) = first(autodiff(Reverse, euroad, Active, Active(x)))[1] - -@test euroad(0.5) ≈ -log(0.5) # -log(1-x) -@test euroad′(0.5) ≈ 2.0 # d/dx -log(1-x) = 1/(1-x) -test_scalar(euroad, 0.5) -end - @testset "Nested AD" begin tonest(x,y) = (x + y)^2 From ff9d320cf694ca2477517ef47f9cdcbf364b8b1d Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 11 Jul 2024 19:18:17 -0400 Subject: [PATCH 168/495] Handle xlogy limit (#1615) * Handle xlogy limit * with test * fixup --- Project.toml | 4 ++ ext/EnzymeLogExpFunctionsExt.jl | 10 +++ src/compiler.jl | 107 ++++++++++++++++---------------- src/compiler/interpreter.jl | 60 +----------------- test/Project.toml | 1 + test/ext/logexpfunctions.jl | 14 +++++ test/runtests.jl | 1 + 7 files changed, 86 insertions(+), 111 deletions(-) create mode 100644 ext/EnzymeLogExpFunctionsExt.jl create mode 100644 test/ext/logexpfunctions.jl diff --git a/Project.toml b/Project.toml index b5f6f50d24..625486b9a2 100644 --- a/Project.toml +++ b/Project.toml @@ -18,11 +18,13 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] 
EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeLogExpFunctionsExt = "LogExpFunctions" EnzymeSpecialFunctionsExt = "SpecialFunctions" EnzymeStaticArraysExt = "StaticArrays" @@ -33,6 +35,7 @@ EnzymeCore = "0.7.5" Enzyme_jll = "0.0.133" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" +LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" SpecialFunctions = "1, 2" @@ -41,5 +44,6 @@ julia = "1.6" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/ext/EnzymeLogExpFunctionsExt.jl b/ext/EnzymeLogExpFunctionsExt.jl new file mode 100644 index 0000000000..b1189170c1 --- /dev/null +++ b/ext/EnzymeLogExpFunctionsExt.jl @@ -0,0 +1,10 @@ +module EnzymeLogExpFunctionsExt + +using LogExpFunctions +using Enzyme + +function __init__() + Enzyme.Compiler.known_ops[typeof(LogExpFunctions.xlogy)] = (:xlogy_jl, 2, nothing) +end + +end diff --git a/src/compiler.jl b/src/compiler.jl index 4090bc304c..974c1f1fd9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -101,6 +101,58 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( @static if VERSION >= v"1.8.0" known_ops[typeof(Base.fma_emulated)] = (:fma, 3, nothing) end +@inline function find_math_method(@nospecialize(func), sparam_vals) + if func ∈ keys(known_ops) + name, arity, toinject = known_ops[func] + Tys = (Float32, Float64) + + if length(sparam_vals) == arity + T = first(sparam_vals) + legal = T ∈ Tys + + if legal + if name == :ldexp + if !(sparam_vals[2] <: Integer) + legal = false + end + elseif name == :pow + if sparam_vals[2] <: Integer + name = :powi + elseif sparam_vals[2] != T + legal = false + end + elseif name == :jl_rem2pi + else + if !all(==(T), sparam_vals) + legal = false + end + end + end + if legal + return name, toinject, T + end + end + end + + if func ∈ 
keys(cmplx_known_ops) + name, arity, toinject = cmplx_known_ops[func] + Tys = (Complex{Float32}, Complex{Float64}) + if length(sparam_vals) == arity + T = first(sparam_vals) + legal = T ∈ Tys + + if legal + if !all(==(T), sparam_vals) + legal = false + end + end + if legal + return name, toinject, T + end + end + end + return nothing, nothing, nothing +end const nofreefns = Set{String}(( "ijl_f_isdefined", "jl_f_isdefined", @@ -5621,61 +5673,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end continue end - - @inline function find_math_method() - if func ∈ keys(known_ops) - name, arity, toinject = known_ops[func] - Tys = (Float32, Float64) - - if length(sparam_vals) == arity - T = first(sparam_vals) - legal = T ∈ Tys - - if legal - if name == :ldexp - if !(sparam_vals[2] <: Integer) - legal = false - end - elseif name == :pow - if sparam_vals[2] <: Integer - name = :powi - elseif sparam_vals[2] != T - legal = false - end - elseif name == :jl_rem2pi - else - if !all(==(T), sparam_vals) - legal = false - end - end - end - if legal - return name, toinject, T - end - end - end - - if func ∈ keys(cmplx_known_ops) - name, arity, toinject = cmplx_known_ops[func] - Tys = (Complex{Float32}, Complex{Float64}) - if length(sparam_vals) == arity - T = first(sparam_vals) - legal = T ∈ Tys - - if legal - if !all(==(T), sparam_vals) - legal = false - end - end - if legal - return name, toinject, T - end - end - end - return nothing, nothing, nothing - end - name, toinject, T = find_math_method() + name, toinject, T = find_math_method(func, sparam_vals) if name === nothing continue end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 95ff12a422..e1652c5895 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -108,65 +108,11 @@ function is_primitive_func(@nospecialize(TT)) if ft == typeof(Enzyme.pmap) return true end - if ft === typeof(Base.rem2pi) - if TT <: Tuple{ft, Float32, <:Any} || TT <: 
Tuple{ft, Float64, <:Any} || TT <: Tuple{ft, Float16, <:Any} - return true - end - end - - if ft == typeof(Base.inv) || ft == typeof(Base.sqrt) - if TT <: Tuple{ft, Complex{Float32}} || TT <: Tuple{ft, Complex{Float64}} - return true - end - end - - @static if VERSION >= v"1.9-" - if ft === typeof(Base.rem) - if TT <: Tuple{ft, Float32, Float32} || TT <: Tuple{ft, Float64, Float64} - return true - end - end + match = Enzyme.Compiler.find_math_method(ft, TT.parameters[2:end])[1] + if match !== nothing + return true end - if ft === typeof(Base.cbrt) || ft === typeof(Base.sin) || ft === typeof(Base.cos) || - ft === typeof(Base.sinc) || - ft === typeof(Base.tan) || ft === typeof(Base.exp) || ft === typeof(Base.FastMath.exp_fast) || - ft === typeof(Base.exp10) || - ft === typeof(Base.exp2) || - ft === typeof(Base.expm1) || - ft === typeof(Base.log) || ft === typeof(Base.FastMath.log) || - ft === typeof(Base.log1p) || - ft === typeof(Base.log2) || - ft === typeof(Base.log10) || - ft === typeof(Base.asin) || - ft === typeof(Base.acos) || - ft === typeof(Base.atan) || - ft === typeof(Base.sinpi) || - ft === typeof(Base.cospi) || - ft === typeof(Base.sinh) || ft === typeof(Base.FastMath.sinh_fast) || - ft === typeof(Base.cosh) || ft === typeof(Base.FastMath.cosh_fast) || - ft === typeof(Base.tanh) || ft === typeof(Base.FastMath.tanh_fast) || - ft === typeof(Base.sqrt) || ft === typeof(Base.sincos) || ft === typeof(Base.sincospi) - if TT <: Tuple{ft, Float32} || TT <: Tuple{ft, Float64} || TT <: Tuple{ft, Float16} - return true - end - end -@static if VERSION < v"1.8.0" -else - if ft === typeof(Base.fma_emulated) - if TT <: Tuple{ft, Float32, Float32, Float32} || TT <: Tuple{ft, Float64, Float64, Float64} - return true - end - end -end - if ft === typeof(Base.:^) || ft === typeof(Base.atan) - if TT <: Tuple{ft, Float32, Float32} || TT <: Tuple{ft, Float64, Float64} - return true - end - if TT <: Tuple{ft, Float32, <:Integer} || TT <: Tuple{ft, Float64, <:Integer} - return 
true - end - end # FIXME(@wsmoses): For which types should we not inline? if ft === typeof(Base.wait) || ft === typeof(Base._wait) || ft === typeof(Base.enq_work) || ft === typeof(Base.Threads.threadid) || ft == typeof(Base.Threads.nthreads) || diff --git a/test/Project.toml b/test/Project.toml index bf44952c27..5c8286d1af 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/ext/logexpfunctions.jl b/test/ext/logexpfunctions.jl new file mode 100644 index 0000000000..69ee7f2e73 --- /dev/null +++ b/test/ext/logexpfunctions.jl @@ -0,0 +1,14 @@ +using LogExpFunctions + + +xlogydiff(x) = xlogy(x[1], 23.0) +@testset "LogExpFunctions" begin + + x = [0.0] + + grad_forward = Enzyme.gradient(Enzyme.Forward, xlogydiff, x) + grad_reverse = Enzyme.gradient(Enzyme.Reverse, xlogydiff, x) + + @test grad_forward[1] ≈ log(23.0) + @test grad_reverse[1] ≈ log(23.0) +end diff --git a/test/runtests.jl b/test/runtests.jl index 610284bfb8..3570c06e52 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3511,6 +3511,7 @@ end @testset "ChainRulesCore ext" begin include("ext/chainrulescore.jl") end + include("ext/logexpfunctions.jl") end From 3eed408e38819ed5674b31f7f28813c8952caea6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 12 Jul 2024 10:35:17 -0400 Subject: [PATCH 169/495] Simplify no derivative message (#1632) --- src/compiler.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 974c1f1fd9..9e8781d871 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1825,7 +1825,7 @@ 
end function Base.showerror(io::IO, ece::NoDerivativeException) print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing && !occursin("No create nofree of empty function", ece.msg) + if ece.ir !== nothing print(io, "Current scope: \n") print(io, ece.ir) end @@ -1996,6 +1996,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end if errtype == API.ET_NoDerivative + if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward mode derivative found for", msg) || occursin("No reverse pass found", msg) + ir = nothing + end exc = NoDerivativeException(msg, ir, bt) if B != C_NULL B = IRBuilder(B) From b3e1ac95f1039c4f4d6e89a5ae9b114006b43ebb Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 13 Jul 2024 21:25:52 -0400 Subject: [PATCH 170/495] Add abstypeof to del_end warning --- src/rules/llvmrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 91d6b44820..f3ece69053 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -983,7 +983,7 @@ function jl_array_del_end_rev(B, orig, gutils, tape) length = LLVM.mul!(B, len, elSize) - GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type" + GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $(abs_typeof(origops[1]))" toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[length]) LLVM.memset!(B, toset, LLVM.ConstantInt(i8, 0, false), elSize, algn) From bed3d6688c985840605001e748aa0b93a16ed498 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 02:00:32 -0400 Subject: [PATCH 171/495] Mark type assert as inactive (#1639) --- src/compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 9e8781d871..0a84db122c 100644 --- a/src/compiler.jl +++ 
b/src/compiler.jl @@ -155,6 +155,7 @@ end end const nofreefns = Set{String}(( + "ijl_typeassert", "jl_typeassert", "ijl_f_isdefined", "jl_f_isdefined", "ijl_field_index", "jl_field_index", "ijl_specializations_get_linfo", "jl_specializations_get_linfo", @@ -237,6 +238,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_typeassert", "jl_typeassert", "ijl_f_isdefined", "jl_f_isdefined", "ijl_field_index", "jl_field_index", "ijl_specializations_get_linfo", "jl_specializations_get_linfo", From 38c2d31f86dd8194d764a574142d06d252da4984 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 10:17:00 -0400 Subject: [PATCH 172/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 625486b9a2..b0c158155b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.22" +version = "0.12.23" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -32,7 +32,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.133" +Enzyme_jll = "0.0.134" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From 4a52bdf2417ffc25fe01543797877358023298b2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 12:43:17 -0400 Subject: [PATCH 173/495] Add more inst type info (#1640) * Add more inst type info * non-box * Update runtests.jl * Let 1.8 be broken --- src/absint.jl | 40 +++++++++++++++++++++++++++++++++++++++- src/compiler.jl | 40 +++++++++++++++++++++++++--------------- src/rules/llvmrules.jl | 2 +- test/runtests.jl | 28 ++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 17 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 5d9c595059..6462162bc6 100644 --- a/src/absint.jl +++ 
b/src/absint.jl @@ -312,8 +312,46 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end end end - end + + if isa(arg, LLVM.ExtractValueInst) + larg = operands(arg)[1] + indptrs = LLVM.API.LLVMGetIndices(arg) + numind = LLVM.API.LLVMGetNumIndices(arg) + offset = Cuint[unsafe_load(indptrs, i) for i in 1:numind] + if isa(larg, LLVM.Argument) || isa(larg, LLVM.ExtractValueInst) + typ, byref = if isa(larg, LLVM.Argument) + f = LLVM.Function(LLVM.API.LLVMGetParamParent(larg)) + idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == larg]) + enzyme_extract_parm_type(f, idx, #=error=#false) + else + found, typ = abs_typeof(larg, partial) + if !found + return (false, nothing) + end + (typ, GPUCompiler.BITS_VALUE) + end + if typ !== nothing && byref == GPUCompiler.BITS_VALUE + for ind in offset + @assert Base.isconcretetype(typ) + cnt = 0 + for i in 1:fieldcount(typ) + styp = fieldtype(typ, i) + if isghostty(styp) + continue + end + if cnt == ind + typ = styp + break + end + cnt+=1 + end + end + return (true, typ) + end + end + end + if isa(arg, LLVM.Argument) f = LLVM.Function(LLVM.API.LLVMGetParamParent(arg)) diff --git a/src/compiler.jl b/src/compiler.jl index 0a84db122c..ac0760abb6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4913,7 +4913,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(wrapper_args, ptr) push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen)))) push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", 
string(UInt(GPUCompiler.BITS_VALUE)))) elseif arg.arg_i in raisedArgs wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm) ctx = LLVM.context(wrapparm) @@ -5802,13 +5802,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; dl = string(LLVM.datalayout(mod)) ctx = LLVM.context(mod) for f in functions(mod), bb in blocks(f), inst in instructions(bb) - if !isa(inst, LLVM.CallInst) - continue - end + fn = isa(inst, LLVM.CallInst) ? LLVM.called_operand(inst) : nothing - fn = LLVM.called_operand(inst) - - if !API.HasFromStack(inst) && (!isa(fn, LLVM.Function) || isempty(blocks(fn))) + if !API.HasFromStack(inst) && isa(inst, LLVM.CallInst) && (!isa(fn, LLVM.Function) || isempty(blocks(fn))) legal, source_typ = abs_typeof(inst) codegen_typ = value_type(inst) if legal @@ -5834,17 +5830,27 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; codegen_typ end - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", string(typetree(typ, ctx, dl, seen)))) + if isa(inst, LLVM.CallInst) + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", string(typetree(typ, ctx, dl, seen)))) + else + metadata(inst)["enzyme_type"] = to_md(typetree(arg.typ, ctx, dl, seen), ctx) + end elseif codegen_typ == T_prjlvalue - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", "{[-1]:Pointer}")) + if isa(inst, LLVM.CallInst) + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", "{[-1]:Pointer}")) + else + metadata(inst)["enzyme_type"] = to_md(typetree(Ptr{Cvoid}, ctx, dl, seen), ctx) + end end end - if !isa(fn, LLVM.Function) - continue - end - if length(blocks(fn)) != 0 - continue + if isa(inst, LLVM.CallInst) + if !isa(fn, LLVM.Function) + continue + end + if length(blocks(fn)) != 0 + continue + end end ty = value_type(inst) 
if ty == LLVM.VoidType() @@ -5858,7 +5864,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if !guaranteed_const_nongen(jTy, world) continue end - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_inactive")) + if isa(inst, LLVM.CallInst) + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_inactive")) + else + metadata(inst)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) + end end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index f3ece69053..b9606cf9f8 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -983,7 +983,7 @@ function jl_array_del_end_rev(B, orig, gutils, tape) length = LLVM.mul!(B, len, elSize) - GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $(abs_typeof(origops[1]))" + GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $(abs_typeof(origops[1])) in $(string(origops[1]))" toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[length]) LLVM.memset!(B, toset, LLVM.ConstantInt(i8, 0, false), elSize, algn) diff --git a/test/runtests.jl b/test/runtests.jl index 3570c06e52..15b3c3ddf4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3132,6 +3132,34 @@ end end end +@static if VERSION < v"1.8-" || VERSION >= v"1.9-" +@inline extract_bc(bc, ::Val{:north}) = (bc.north) +@inline extract_bc(bc, ::Val{:top}) = (bc.top) + +function permute_boundary_conditions(boundary_conditions) + sides = [:top, :north] # changing the order of these actually changes the error + boundary_conditions = Tuple(extract_bc(boundary_conditions, Val(side)) for side in sides) + + return nothing +end + +@testset "Extract abstype" begin + + parameters = (a = 1, b = 0.1) + + bc = (north=1, top=tuple(parameters, tuple(:c))) + d_bc = Enzyme.make_zero(bc) + Enzyme.API.looseTypeAnalysis!(true) + + dc²_dκ = 
autodiff(Enzyme.Reverse, + permute_boundary_conditions, + Duplicated(bc, d_bc)) + + Enzyme.API.looseTypeAnalysis!(false) +end +end + + @testset "Static activity" begin struct Test2{T} From ff11b44bf8730edc8e7272b0aac71694939f4a21 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 14 Jul 2024 13:14:28 -0400 Subject: [PATCH 174/495] Improve names on alloca lowering convention --- src/compiler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index ac0760abb6..038645a253 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4869,7 +4869,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function dl = string(LLVM.datalayout(LLVM.parent(entry_f))) if sret if !in(0, parmsRemoved) - sretPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1]))) + sretPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1])), "innersret") ctx = LLVM.context(entry_f) if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) @@ -4879,7 +4879,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) - retRootPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1+sret]))) + retRootPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1+sret])), "innerreturnroots") # retRootPtr = alloca!(builder, parameters(wrapper_f)[1]) push!(wrapper_args, retRootPtr) end @@ -4898,7 +4898,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if !isa(ty, LLVM.PointerType) throw(AssertionError("ty is not a LLVM.PointerType: entry_f = $(entry_f), args = $(args), parm = $(parm), ty = $(ty)")) end - ptr = alloca!(builder, eltype(ty)) + ptr = alloca!(builder, eltype(ty), LLVM.name(parm)*".innerparm") if TT !== nothing && TT.parameters[arg.arg_i] <: Const metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end From 
d527ffa2ff06a1e59e46c6bce536b7fca2567f6c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 13:38:30 -0400 Subject: [PATCH 175/495] First stab at nongenabi (#1575) * First stab at nongenabi * Update compiler.jl * completely rm world * fix * Simple tests * fix * no world in compiler job * fix --- lib/EnzymeCore/src/EnzymeCore.jl | 8 ++- src/Enzyme.jl | 73 +++++++++++++++++-------- src/compiler.jl | 94 +++++++++++++++++--------------- src/rules/jitrules.jl | 18 ++++-- src/rules/parallelrules.jl | 6 +- test/abi.jl | 15 +++++ 6 files changed, 139 insertions(+), 75 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index d82072724f..fa0a31d44a 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -4,7 +4,7 @@ export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWi export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed export MixedDuplicated, BatchMixedDuplicated -export DefaultABI, FFIABI, InlineABI +export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc function batch_size end @@ -196,6 +196,12 @@ struct FFIABI <: ABI end Inlining function call ABI. """ struct InlineABI <: ABI end +""" + struct NonGenABI <: ABI + +Non-generated function ABI. 
+""" +struct NonGenABI <: ABI end const DefaultABI = FFIABI """ diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7c283d0d1e..05ba2ae4c0 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,8 +5,8 @@ import EnzymeCore import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI +import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI +export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -239,7 +239,6 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) rt = if A isa UnionAll Core.Compiler.return_type(f.val, tt) @@ -247,9 +246,15 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) eltype(A) end + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(Core.Typeof(f.val), tt)) + end + if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# 
Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) + forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) res = forward(f, args...) tape = res[1] if ReturnPrimal @@ -279,7 +284,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) args = seed_complex_args(seen, seen2, args...) tt′ = vaTypeof(args...) - thunk = Enzyme.Compiler.thunk(Val(world), typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) results = thunk(f, args..., (rt(0), rt(1), rt(im))) @@ -301,7 +306,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. 
See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.")) end - thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) if A <: Active args = (args..., Compiler.default_adjoint(rt)) @@ -410,9 +415,14 @@ f(x) = x*x ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) - thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(Core.Typeof(f.val), tt)) + end + + thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) thunk(f, args...) 
end @@ -606,13 +616,17 @@ result, ∂v, ∂A end tt = Tuple{map(eltype, args)...} - - world = codegen_world_age(eltype(FA), tt) if !(A <: Const) @assert ReturnShadow end - Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + tt′ = Tuple{args...} + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(eltype(FA), tt)) + end + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) end """ @@ -671,10 +685,14 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(eltype, args)...} - - world = codegen_world_age(eltype(FA), tt) - Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) + tt′ = Tuple{args...} + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(eltype(FA), tt)) + end + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) end @inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} @@ -698,8 +716,12 @@ end TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} - world = codegen_world_age(eltype(FA), primal_tt) - nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), 
RABI) + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), TT) + else + Val(codegen_world_age(eltype(FA), primal_tt)) + end + nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) TapeType = EnzymeRules.tape_type(nondef[1]) return TapeType end @@ -1220,12 +1242,15 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) tt′ = Tuple{BatchDuplicated{Core.Typeof(x), chunk}} tt = Tuple{Core.Typeof(x)} - world = codegen_world_age(Core.Typeof(f), tt) rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} - World = Val(nothing) - primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(Core.Typeof(f), tt)) + end + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) if num * chunk == n_out_val last_size = chunk @@ -1233,7 +1258,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) else last_size = n_out_val - (num-1)*chunk tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) end tmp = 
ntuple(num) do i @@ -1260,11 +1285,15 @@ end @assert !ReturnPrimal tt′ = Tuple{Duplicated{Core.Typeof(x)}} tt = Tuple{Core.Typeof(x)} - world = codegen_world_age(Core.Typeof(f), tt) rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} - primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(Core.Typeof(f), tt)) + end + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) rows = ntuple(n_outs) do i Base.@_inline_meta dx = zero(x) diff --git a/src/compiler.jl b/src/compiler.jl index 038645a253..aa4e36c8c3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -698,7 +698,7 @@ struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT adjoint::PT end -struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World} <: AbstractThunk{FA, RT, TT, Width} +struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal} <: AbstractThunk{FA, RT, TT, Width} adjoint::PT end @@ -3226,13 +3226,17 @@ import .Interpreter: isKWCallSignature """ Create the methodinstance pair, and lookup the primal return type. """ -@inline function fspec(@nospecialize(F), @nospecialize(TT), world::Integer) +@inline function fspec(@nospecialize(F), @nospecialize(TT), world::Union{Integer, Nothing}=nothing) # primal function. 
Inferred here to get return type _tt = (TT.parameters...,) primal_tt = Tuple{map(eltype, _tt)...} - primal = GPUCompiler.methodinstance(F, primal_tt, world) + primal = if world isa Nothing + GPUCompiler.methodinstance(F, primal_tt) + else + GPUCompiler.methodinstance(F, primal_tt, world) + end return primal end @@ -6108,8 +6112,8 @@ struct CompileResult{AT, PT} TapeType::Type end -@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} = -enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) +@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} = +enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) 
@@ -6703,7 +6707,7 @@ function _thunk(job, postopt::Bool=true) # Run post optimization pipeline if postopt - if job.config.params.ABI <: FFIABI + if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) else propagate_returned!(mod) @@ -6742,13 +6746,16 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} +@inline function thunkbase(mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} JuliaContext() do ctx - mi = fspec(eltype(FA), TT, World) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) - tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + tmp_job = if World isa Nothing + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) + else + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + end interp = GPUCompiler.get_interpreter(tmp_job) @@ -6759,32 +6766,35 @@ end run_enzyme = true - if rrt == Union{} + A2 = if rrt == Union{} run_enzyme = false - A = Const + Const + else + A end - if run_enzyme && !(A <: Const) && guaranteed_const_nongen(rrt, World) - estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" - return quote 
- error($estr) - end + if run_enzyme && !(A2 <: Const) && guaranteed_const_nongen(rrt, World) + estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" + return error(estr) end rt2 = if !run_enzyme Const{rrt} - elseif A isa UnionAll - A{rrt} + elseif A2 isa UnionAll + A2{rrt} else @assert A isa DataType # Can we relax this condition? # @assert eltype(A) == rrt - A + A2 end params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) - job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - + job = if World isa Nothing + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) + else + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + end # We need to use primal as the key, to lookup the right method # but need to mixin the hash of the adjoint to avoid cache collisions # This is counter-intuitive since we would expect the cache to be split @@ -6794,46 +6804,42 @@ end compile_result = cached_compilation(job) if !run_enzyme - ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World} + ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal} if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient - return quote - Base.@_inline_meta - ($ErrT($(compile_result.adjoint)), $ErrT($(compile_result.adjoint))) - end + return (ErrT(compile_result.adjoint), ErrT(compile_result.adjoint)) else - return quote - Base.@_inline_meta - $ErrT($(compile_result.adjoint)) - end + return ErrT(compile_result.adjoint) end elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} AdjT = 
AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} - return quote - Base.@_inline_meta - augmented = $AugT($(compile_result.primal)) - adjoint = $AdjT($(compile_result.adjoint)) - (augmented, adjoint) - end + return (AugT(compile_result.primal), AdjT(compile_result.adjoint)) elseif Mode == API.DEM_ReverseModeCombined CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} - return quote - Base.@_inline_meta - $CAdjT($(compile_result.adjoint)) - end + return CAdjT(compile_result.adjoint) elseif Mode == API.DEM_ForwardMode FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} - return quote - Base.@_inline_meta - $FMT($(compile_result.adjoint)) - end + return FMT(compile_result.adjoint) else @assert false end end end +@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI} + return thunkbase(mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) +end + +@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} + mi = fspec(eltype(FA), TT, World) + res = thunkbase(mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + return quote + Base.@_inline_meta + return $(res) + end +end + import GPUCompiler: deferred_codegen_jobs @generated function deferred_codegen(::Val{World}, 
::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 763e815d49..e99d118772 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -204,7 +204,8 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end world = codegen_world_age(FT, tt) - forward = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + opt_mi = Val(world) + forward = thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) @@ -304,7 +305,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end world = codegen_world_age(FT, tt) - forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + opt_mi = Val(world) + forward, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) @@ -439,7 +441,8 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act world = codegen_world_age(FT, tt) - _, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + opt_mi = Val(world) + _, adjoint = thunk(opt_mi, dupClosure0 ? 
Duplicated{FT} : Const{FT}, annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) @@ -697,7 +700,8 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType else Const(f) end - res = thunk(Val(world), FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + opt_mi = Val(world) + res = thunk(opt_mi, FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(fa, args...) return if annotation <: Const ReturnType(allFirst(Val(width+1), res)) @@ -830,7 +834,8 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} Const(f) end world = codegen_world_age(FT, tt) - forward, adjoint = thunk(Val(world), FA, + opt_mi = Val(world) + forward, adjoint = thunk(opt_mi, FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) forward(fa, args...) @@ -976,7 +981,8 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween else Const(f) end - forward, adjoint = thunk(Val(world), FA, + opt_mi = Val(world) + forward, adjoint = thunk(opt_mi, FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index c13e21c1ce..6aa661feef 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -2,7 +2,8 @@ function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width} FT = Core.Typeof(fn) ghos = guaranteed_const(FT) - forward = thunk(world, (ghos ? 
Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + opt_mi = world + forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) ft = ghos ? Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -16,7 +17,8 @@ function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, # TODO make this AD subcall type stable FT = Core.Typeof(fn) ghos = guaranteed_const(FT) - forward, adjoint = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + opt_mi = world + forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) ft = ghos ? 
Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() diff --git a/test/abi.jl b/test/abi.jl index d371a7d0a0..93bf471fde 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -9,7 +9,11 @@ using Test res = autodiff(Reverse, f, Const, Const(nothing)) @test res === ((nothing,),) + res = autodiff(ReverseMode{false,NonGenABI, false}(), f, Const, Const(nothing)) + @test res === ((nothing,),) + @test () === autodiff(Forward, f, Const, Const(nothing)) + @test () === autodiff(ForwardMode{NonGenABI}(), f, Const, Const(nothing)) res = autodiff(Reverse, f, Const(nothing)) @test res === ((nothing,),) @@ -18,7 +22,11 @@ using Test res = autodiff_deferred(Reverse, f, Const(nothing)) @test res === ((nothing,),) + res = autodiff_deferred(ReverseMode{false,NonGenABI, false}(), f, Const, Const(nothing)) + @test res === ((nothing,),) + @test () === autodiff_deferred(Forward, f, Const(nothing)) + @test () === autodiff_deferred(ForwardMode{NonGenABI}(), f, Const, Const(nothing)) # ConstType -> Type{Int} res = autodiff(Reverse, f, Const, Const(Int)) @@ -56,10 +64,17 @@ using Test unused(_, y) = y _, res0 = autodiff(Reverse, unused, Active, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 + + _, res0 = autodiff(ReverseMode{false, NonGenABI, false}(), unused, Active, Const(nothing), Active(2.0))[1] + @test res0 ≈ 1.0 + res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) @test res0 ≈ 1.0 res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), DuplicatedNoNeed(2.0, 1.0)) @test res0 ≈ 1.0 + + res0, = autodiff(ForwardMode{NonGenABI}(), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) + @test res0 ≈ 1.0 _, res0 = autodiff(Reverse, unused, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 From 0169959c35be2fd49ef154b3f9a88977eb9bd568 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 13:38:48 -0400 Subject: [PATCH 176/495] Update Project.toml --- lib/EnzymeCore/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 57bba3fd71..e0861dadda 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.6" +version = "0.7.7" [compat] Adapt = "3, 4" From 3d5f8e237cb5c5a39b7f53c9bc1ac1478b4e9ccd Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 13:39:09 -0400 Subject: [PATCH 177/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b0c158155b..3200fe8d45 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ EnzymeStaticArraysExt = "StaticArrays" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.5" +EnzymeCore = "0.7.7" Enzyme_jll = "0.0.134" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" From 674fd0d21d394019ef91888ba6df9546e1152178 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 20:23:30 -0400 Subject: [PATCH 178/495] Cleanup objid attributes (#1642) --- src/compiler.jl | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index aa4e36c8c3..8f0058ef20 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -155,6 +155,7 @@ end end const nofreefns = Set{String}(( + "ijl_gc_run_pending_finalizers", "jl_gc_run_pending_finalizers", "ijl_typeassert", "jl_typeassert", "ijl_f_isdefined", "jl_f_isdefined", "ijl_field_index", "jl_field_index", @@ -3312,7 +3313,7 @@ function annotate!(mod, mode) end end - for fname in ("julia.typeof",) + for fname in ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") if haskey(fns, fname) fn = fns[fname] if LLVM.version().major <= 15 @@ -3487,17 +3488,6 @@ function annotate!(mod, mode) end end - for rfn in ("jl_object_id_", "jl_object_id", 
"ijl_object_id_", "ijl_object_id") - if haskey(fns, rfn) - fn = fns[rfn] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) - end - end - end - # Key of jl_eqtable_get/put is inactive, definitionally for rfn in ("jl_eqtable_get", "ijl_eqtable_get") if haskey(fns, rfn) From f16795c43969267768fada72db9002099bf7afee Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 20:24:21 -0400 Subject: [PATCH 179/495] Cleanup solve rules and bigfloat (#1641) * Cleanup solve rules and bigfloat * fix * Fix aug fwd msg * fix cache ty --- src/compiler.jl | 2 +- src/internal_rules.jl | 13 +++++++++++-- src/typetree.jl | 4 ++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8f0058ef20..06a681dca7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1999,7 +1999,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end if errtype == API.ET_NoDerivative - if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward mode derivative found for", msg) || occursin("No reverse pass found", msg) + if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward pass", msg) || occursin("No reverse pass found", msg) ir = nothing end exc = NoDerivativeException(msg, ir, bt) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index fd18bb0261..65933b4237 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -390,11 +390,20 @@ else } end - cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), UT, typeof(cache_b)}}( + cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{ + eltype(RT), + 
EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing, + UT, + typeof(cache_b) + }}( (cache_res, dres, cache_A, cache_b) ) - return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache) + return EnzymeRules.AugmentedReturn{ + EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing, + EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing, + typeof(cache) + }(retres, dres, cache) end function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT diff --git a/src/typetree.jl b/src/typetree.jl index 2c846ae49e..2ab6cd4a50 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -111,6 +111,10 @@ function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Double, -1, ctx) end +function typetree_inner(::Type{BigFloat}, ctx, dl, seen::TypeTreeTable) + return TypeTree() +end + function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat} GPUCompiler.@safe_warn "Unknown floating point type" T return TypeTree() From 3de5f617c980c976f333ddbfa7a1e478f46fb4cb Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 14 Jul 2024 22:02:46 -0400 Subject: [PATCH 180/495] Remove unnecessary jl_array_del_end zero-set used memset error (#1643) --- src/rules/llvmrules.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index b9606cf9f8..beecfa5a2f 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -974,16 +974,22 @@ function jl_array_del_end_rev(B, orig, gutils, tape) end args = LLVM.Value[anti, offset] + found, arty = abs_typeof(origops[1]) anti = shadowin - elSize = get_array_elsz(B, anti) - elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) + 
elSize = if found + LLVM.ConstantInt(Csize_t(sizeof(eltype(arty)))) + else + elSize = LLVM.zext!(B, get_array_elsz(B, anti), LLVM.IntType(8*sizeof(Csize_t))) + end len = get_array_len(B, anti) LLVM.call!(B, fty, delF, args) length = LLVM.mul!(B, len, elSize) - GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $(abs_typeof(origops[1])) in $(string(origops[1]))" + if !found && !(eltype(arty) <: Base.IEEEFloat) + GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $((found, arty)) in $(string(origops[1]))" + end toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[length]) LLVM.memset!(B, toset, LLVM.ConstantInt(i8, 0, false), elSize, algn) From 40f009e4ce8114c854db3f75d7fef651ec5d589c Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 15 Jul 2024 09:56:57 -0400 Subject: [PATCH 181/495] Update Project.toml (#1646) --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3200fe8d45..e0f9f3f645 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.23" +version = "0.12.24" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -32,7 +32,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.7" -Enzyme_jll = "0.0.134" +Enzyme_jll = "0.0.135" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From 0382d59ea21696c7776c8d588ddec30a63ca586f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 15 Jul 2024 10:53:52 -0400 Subject: [PATCH 182/495] Reduce time for enzyme import (#1645) * Reduce time for enzyme import * fix symbol * tryfix * fix * fix * Update compiler.jl --- src/compiler.jl | 884 +++++++++++++++++---------------- src/rules/customrules.jl | 11 +- 
src/rules/jitrules.jl | 24 +- src/rules/llvmrules.jl | 196 +++++--- src/rules/parallelrules.jl | 30 +- src/rules/typeunstablerules.jl | 43 +- 6 files changed, 634 insertions(+), 554 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 06a681dca7..f53996aed5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1335,13 +1335,11 @@ function allocate_sret!(B::LLVM.IRBuilder, N) end function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) - sret = LLVM.IRBuilder() do B - position!(B, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) - allocate_sret!(B, N) - end + B = LLVM.IRBuilder() + position!(B, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + allocate_sret!(B, N) end - @inline function EnzymeCore.make_zero(x::FT)::FT where {FT <: AbstractFloat} return Base.zero(x) end @@ -3137,6 +3135,10 @@ function __init__() shadow_alloc_rewrite, Cvoid, (LLVM.API.LLVMValueRef,API.EnzymeGradientUtilsRef))) register_alloc_rules() register_llvm_rules() + + # Force compilation of AD stack + # thunk = Enzyme.Compiler.thunk(Enzyme.Compiler.fspec(typeof(Base.identity), Tuple{Active{Float64}}), Const{typeof(Base.identity)}, Active, Tuple{Active{Float64}}, #=Split=# Val(Enzyme.API.DEM_ReverseModeCombined), #=width=#Val(1), #=ModifiedBetween=#Val((false,false)), Val(#=ReturnPrimal=#false), #=ShadowInit=#Val(false), NonGenABI) + # thunk(Const(Base.identity), Active(1.0), 1.0) end # Define EnzymeTarget @@ -4042,364 +4044,363 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, params = [parameters(llvm_f)...] 
- LLVM.IRBuilder() do builder - entry = BasicBlock(llvm_f, "entry") - position!(builder, entry) + builder = LLVM.IRBuilder() + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) - realparms = LLVM.Value[] - i = 1 + realparms = LLVM.Value[] + i = 1 - if returnRoots - sret = params[i] - i+= 1 + if returnRoots + sret = params[i] + i+= 1 - attr = if LLVM.version().major >= 12 - TypeAttribute("sret", jltype) - else - EnumAttribute("sret") - end - push!(parameter_attributes(llvm_f, 1), attr) - push!(parameter_attributes(llvm_f, 1), EnumAttribute("noalias")) - push!(parameter_attributes(llvm_f, 2), EnumAttribute("noalias")) - elseif jltype != T_void - sret = alloca!(builder, jltype) - end - rootRet = nothing - if returnRoots - rootRet = params[i] - i+=1 + attr = if LLVM.version().major >= 12 + TypeAttribute("sret", jltype) + else + EnumAttribute("sret") end + push!(parameter_attributes(llvm_f, 1), attr) + push!(parameter_attributes(llvm_f, 1), EnumAttribute("noalias")) + push!(parameter_attributes(llvm_f, 2), EnumAttribute("noalias")) + elseif jltype != T_void + sret = alloca!(builder, jltype) + end + rootRet = nothing + if returnRoots + rootRet = params[i] + i+=1 + end - activeNum = 0 + activeNum = 0 - for T in TT.parameters - T′ = eltype(T) + for T in TT.parameters + T′ = eltype(T) - if isghostty(T′) || Core.Compiler.isconstType(T′) - continue - end + if isghostty(T′) || Core.Compiler.isconstType(T′) + continue + end - isboxed = GPUCompiler.deserves_argbox(T′) + isboxed = GPUCompiler.deserves_argbox(T′) - llty = value_type(params[i]) + llty = value_type(params[i]) - convty = convert(LLVMType, T′; allow_boxed=true) + convty = convert(LLVMType, T′; allow_boxed=true) - if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) - al = emit_allocobj!(builder, Base.RefValue{T′}) - al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) - store!(builder, params[i], al) - 
al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) - push!(realparms, al) - else - push!(realparms, params[i]) - end + if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) + al = emit_allocobj!(builder, Base.RefValue{T′}) + al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + store!(builder, params[i], al) + al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) + push!(realparms, al) + else + push!(realparms, params[i]) + end - i += 1 - if T <: Const - elseif T <: Active - isboxed = GPUCompiler.deserves_argbox(T′) - if isboxed - if is_split - msg = sprint() do io - println(io, "Unimplemented: Had active input arg needing a box in split mode") - println(io, T, " at index ", i) - println(io, TT) - end - throw(AssertionError(msg)) - end - @assert !is_split - # TODO replace with better enzyme_zero - ptr = gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)]) - cst = pointercast!(builder, ptr, ptr8) - push!(realparms, ptr) - - LLVM.memset!(builder, cst, LLVM.ConstantInt(LLVM.IntType(8), 0), - LLVM.ConstantInt(LLVM.IntType(64), LLVM.storage_size(dl, Base.eltype(LLVM.value_type(ptr)) )), - #=align=#0 ) - end - activeNum += 1 - elseif T <: Duplicated || T <: DuplicatedNoNeed - push!(realparms, params[i]) - i += 1 - elseif T <: MixedDuplicated || T <: BatchMixedDuplicated - parmsi = params[i] - - if T <: BatchMixedDuplicated - if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}}) - njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue) - parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi)))) - parmsi = load!(builder, njlvalue, parmsi) + i += 1 + if T <: Const + elseif T <: Active + isboxed = GPUCompiler.deserves_argbox(T′) + if isboxed + if is_split + msg = sprint() do io + println(io, "Unimplemented: Had active input arg needing a box in 
split mode") + println(io, T, " at index ", i) + println(io, TT) end + throw(AssertionError(msg)) end + @assert !is_split + # TODO replace with better enzyme_zero + ptr = gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)]) + cst = pointercast!(builder, ptr, ptr8) + push!(realparms, ptr) - isboxed = GPUCompiler.deserves_argbox(T′) - - resty = isboxed ? llty : LLVM.PointerType(llty, Derived) + LLVM.memset!(builder, cst, LLVM.ConstantInt(LLVM.IntType(8), 0), + LLVM.ConstantInt(LLVM.IntType(64), LLVM.storage_size(dl, Base.eltype(LLVM.value_type(ptr)) )), + #=align=#0 ) + end + activeNum += 1 + elseif T <: Duplicated || T <: DuplicatedNoNeed + push!(realparms, params[i]) + i += 1 + elseif T <: MixedDuplicated || T <: BatchMixedDuplicated + parmsi = params[i] - ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty))) - for idx in 1:width - pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx-1) - pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) - pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived)) - if isboxed - pv = load!(builder, llty, pv, "mixedboxload") - end - ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1) + if T <: BatchMixedDuplicated + if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}}) + njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue) + parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi)))) + parmsi = load!(builder, njlvalue, parmsi) end + end - push!(realparms, ival) - i += 1 - elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed - isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′}) - val = params[i] + isboxed = GPUCompiler.deserves_argbox(T′) + + resty = isboxed ? llty : LLVM.PointerType(llty, Derived) + + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty))) + for idx in 1:width + pv = (width == 1) ? 
parmsi : extract_value!(builder, parmsi, idx-1) + pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) + pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived)) if isboxed - val = load!(builder, val) - end - i += 1 - push!(realparms, val) - elseif T <: BatchDuplicatedFunc - Func = get_func(T) - funcspec = GPUCompiler.methodinstance(Func, Tuple{}, world) - llvmf = nested_codegen!(Mode, mod, funcspec, world) - push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) - Func_RT = Core.Compiler.typeinf_ext_toplevel(interp, funcspec).rettype - @assert Func_RT == NTuple{width, T′} - _, psret, _ = get_return_info(Func_RT) - args = LLVM.Value[] - if psret !== nothing - psret = alloca!(builder, convert(LLVMType, Func_RT)) - push!(args, psret) + pv = load!(builder, llty, pv, "mixedboxload") end - res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args) - if LLVM.get_subprogram(llvmf) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) - end - if psret !== nothing - res = load!(builder, convert(LLVMType, Func_RT), psret) - end - push!(realparms, res) - else - @assert false + ival = (width == 1 ) ? 
pv : insert_value!(builder, ival, pv, idx-1) end - end - if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) - push!(realparms, params[i]) + push!(realparms, ival) i += 1 - end - - if needs_tape - # Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float} - # and that Bool -> i8, not i1 - tparm = params[i] - tparm = calling_conv_fixup(builder, tparm, tape) - push!(realparms, tparm) + elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed + isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′}) + val = params[i] + if isboxed + val = load!(builder, val) + end i += 1 + push!(realparms, val) + elseif T <: BatchDuplicatedFunc + Func = get_func(T) + funcspec = GPUCompiler.methodinstance(Func, Tuple{}, world) + llvmf = nested_codegen!(Mode, mod, funcspec, world) + push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) + Func_RT = Core.Compiler.typeinf_ext_toplevel(interp, funcspec).rettype + @assert Func_RT == NTuple{width, T′} + _, psret, _ = get_return_info(Func_RT) + args = LLVM.Value[] + if psret !== nothing + psret = alloca!(builder, convert(LLVMType, Func_RT)) + push!(args, psret) + end + res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args) + if LLVM.get_subprogram(llvmf) !== nothing + metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + end + if psret !== nothing + res = load!(builder, convert(LLVMType, Func_RT), psret) + end + push!(realparms, res) + else + @assert false end + end - val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) - end + if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) + push!(realparms, params[i]) + i += 1 + end - @inline function fixup_abi(index, value) - valty = sret_types[index] - # Union becoming 
part of a tuple needs to be adjusted - # See https://github.com/JuliaLang/julia/blob/81afdbc36b365fcbf3ae25b7451c6cb5798c0c3d/src/cgutils.cpp#L3795C1-L3801C121 - if valty isa Union - T_int8 = LLVM.Int8Type() - if value_type(value) == T_int8 - value = nuwsub!(builder, value, LLVM.ConstantInt(T_int8, 1)) - end + if needs_tape + # Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float} + # and that Bool -> i8, not i1 + tparm = params[i] + tparm = calling_conv_fixup(builder, tparm, tape) + push!(realparms, tparm) + i += 1 + end + + val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) + if LLVM.get_subprogram(llvm_f) !== nothing + metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + end + + @inline function fixup_abi(index, value) + valty = sret_types[index] + # Union becoming part of a tuple needs to be adjusted + # See https://github.com/JuliaLang/julia/blob/81afdbc36b365fcbf3ae25b7451c6cb5798c0c3d/src/cgutils.cpp#L3795C1-L3801C121 + if valty isa Union + T_int8 = LLVM.Int8Type() + if value_type(value) == T_int8 + value = nuwsub!(builder, value, LLVM.ConstantInt(T_int8, 1)) end - return value end + return value + end - if Mode == API.DEM_ReverseModePrimal + if Mode == API.DEM_ReverseModePrimal - # if in split mode and the return is a union marked duplicated, upgrade floating point like shadow returns into ref{ty} since otherwise use of the value will create problems. 
- # 3 is index of shadow - if existed[3] != 0 && sret_union && active_reg_inner(pactualRetType, (), world, #=justActive=#Val(true), #=UnionSret=#Val(true)) == ActiveState - rewrite_union_returns_as_ref(enzymefn, data[3], world, width) - end - returnNum = 0 - for i in 1:3 - if existed[i] != 0 - eval = val - if data[i] != -1 - eval = extract_value!(builder, val, data[i]) - end - if i == 3 - if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated - ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) - for idx in 1:width - pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1) - al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}) - llty = value_type(pv) - al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) - store!(builder, pv, al) - emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv)) - ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1) - end - eval = ival + # if in split mode and the return is a union marked duplicated, upgrade floating point like shadow returns into ref{ty} since otherwise use of the value will create problems. + # 3 is index of shadow + if existed[3] != 0 && sret_union && active_reg_inner(pactualRetType, (), world, #=justActive=#Val(true), #=UnionSret=#Val(true)) == ActiveState + rewrite_union_returns_as_ref(enzymefn, data[3], world, width) + end + returnNum = 0 + for i in 1:3 + if existed[i] != 0 + eval = val + if data[i] != -1 + eval = extract_value!(builder, val, data[i]) + end + if i == 3 + if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) + for idx in 1:width + pv = (width == 1) ? 
eval : extract_value!(builder, eval, idx-1) + al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}) + llty = value_type(pv) + al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + store!(builder, pv, al) + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv)) + ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1) end + eval = ival end - eval = fixup_abi(i, eval) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) - ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) - si = store!(builder, eval, ptr) - returnNum+=1 - if i == 3 && shadow_init - shadows = LLVM.Value[] - if width == 1 - push!(shadows, eval) - else - for i in 1:width - push!(shadows, extract_value!(builder, eval, i-1)) - end + end + eval = fixup_abi(i, eval) + ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) + si = store!(builder, eval, ptr) + returnNum+=1 + if i == 3 && shadow_init + shadows = LLVM.Value[] + if width == 1 + push!(shadows, eval) + else + for i in 1:width + push!(shadows, extract_value!(builder, eval, i-1)) end + end - cf = nested_codegen!(Mode, mod, add_one_in_place, Tuple{Any}, world) - push!(function_attributes(cf), EnumAttribute("alwaysinline", 0)) - for shadowv in shadows - c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) - end + cf = nested_codegen!(Mode, mod, add_one_in_place, Tuple{Any}, world) + push!(function_attributes(cf), EnumAttribute("alwaysinline", 0)) + for shadowv in shadows + c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) + if LLVM.get_subprogram(llvm_f) !== nothing + metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, 
LLVM.get_subprogram(llvm_f) ) end end - elseif !isghostty(sret_types[i]) - ty = sret_types[i] - # if primal return, we can upgrade to the full known type - if i == 2 - ty = actualRetType - end - @assert !(isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) ) - @assert Core.Compiler.isconstType(ty) - eval = makeInstanceOf(ty) - eval = fixup_abi(i, eval) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) - ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) - si = store!(builder, eval, ptr) - returnNum+=1 - end - end - @assert returnNum == numLLVMReturns - elseif Mode == API.DEM_ForwardMode - count_Sret = 0 - count_llvm_Sret = 0 - if !isghostty(actualRetType) - if returnPrimal - count_llvm_Sret += 1 - end - if !(rettype <: Const) - count_llvm_Sret += 1 - end - end - if !isghostty(literal_rt) - if returnPrimal - count_Sret += 1 end - if !(rettype <: Const) - count_Sret += 1 + elseif !isghostty(sret_types[i]) + ty = sret_types[i] + # if primal return, we can upgrade to the full known type + if i == 2 + ty = actualRetType end - end - for returnNum in 0:(count_Sret-1) - eval = fixup_abi(returnNum+1, if count_llvm_Sret == 0 - makeInstanceOf(sret_types[returnNum+1]) - elseif count_llvm_Sret == 1 - val - else - @assert count_llvm_Sret > 1 - extract_value!(builder, val, returnNum) - end) + @assert !(isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) ) + @assert Core.Compiler.isconstType(ty) + eval = makeInstanceOf(ty) + eval = fixup_abi(i, eval) ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) + returnNum+=1 end - @assert count_Sret == numLLVMReturns - else - activeNum = 0 - returnNum = 0 - if Mode == API.DEM_ReverseModeCombined - if returnPrimal - if 
!isghostty(literal_rt) - eval = fixup_abi(returnNum+1, if !isghostty(actualRetType) - extract_value!(builder, val, returnNum) - else - makeInstanceOf(sret_types[returnNum+1]) - end) - store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), length(elements(jltype))-1 )])) - returnNum+=1 - end + end + @assert returnNum == numLLVMReturns + elseif Mode == API.DEM_ForwardMode + count_Sret = 0 + count_llvm_Sret = 0 + if !isghostty(actualRetType) + if returnPrimal + count_llvm_Sret += 1 + end + if !(rettype <: Const) + count_llvm_Sret += 1 + end + end + if !isghostty(literal_rt) + if returnPrimal + count_Sret += 1 + end + if !(rettype <: Const) + count_Sret += 1 + end + end + for returnNum in 0:(count_Sret-1) + eval = fixup_abi(returnNum+1, if count_llvm_Sret == 0 + makeInstanceOf(sret_types[returnNum+1]) + elseif count_llvm_Sret == 1 + val + else + @assert count_llvm_Sret > 1 + extract_value!(builder, val, returnNum) + end) + ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) + si = store!(builder, eval, ptr) + end + @assert count_Sret == numLLVMReturns + else + activeNum = 0 + returnNum = 0 + if Mode == API.DEM_ReverseModeCombined + if returnPrimal + if !isghostty(literal_rt) + eval = fixup_abi(returnNum+1, if !isghostty(actualRetType) + extract_value!(builder, val, returnNum) + else + makeInstanceOf(sret_types[returnNum+1]) + end) + store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), length(elements(jltype))-1 )])) + returnNum+=1 end end - for T in TT.parameters[2:end] - if T <: Active - T′ = eltype(T) - isboxed = GPUCompiler.deserves_argbox(T′) - if !isboxed - eval = extract_value!(builder, val, returnNum) - store!(builder, eval, inbounds_gep!(builder, jltype, 
sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)])) - returnNum+=1 - end - activeNum+=1 + end + for T in TT.parameters[2:end] + if T <: Active + T′ = eltype(T) + isboxed = GPUCompiler.deserves_argbox(T′) + if !isboxed + eval = extract_value!(builder, val, returnNum) + store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)])) + returnNum+=1 end + activeNum+=1 end - @assert (returnNum - activeNum) + (activeNum != 0 ? 1 : 0) == numLLVMReturns end + @assert (returnNum - activeNum) + (activeNum != 0 ? 1 : 0) == numLLVMReturns + end - if returnRoots - count = 0 - todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType}[([LLVM.ConstantInt(LLVM.IntType(64), 0)], jltype)] - while length(todo) != 0 - path, ty = popfirst!(todo) - if isa(ty, LLVM.PointerType) - loc = inbounds_gep!(builder, root_ty, rootRet, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), count)]) - count+=1 - outloc = inbounds_gep!(builder, jltype, sret, path) - store!(builder, load!(builder, ty, outloc), loc) - continue - end - if isa(ty, LLVM.ArrayType) - if any_jltypes(ty) - for i=1:length(ty) - npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) - push!(todo, (npath, eltype(ty))) - end + if returnRoots + count = 0 + todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType}[([LLVM.ConstantInt(LLVM.IntType(64), 0)], jltype)] + while length(todo) != 0 + path, ty = popfirst!(todo) + if isa(ty, LLVM.PointerType) + loc = inbounds_gep!(builder, root_ty, rootRet, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), count)]) + count+=1 + outloc = inbounds_gep!(builder, jltype, sret, path) + store!(builder, load!(builder, ty, outloc), loc) + continue + end + if isa(ty, LLVM.ArrayType) + if any_jltypes(ty) + for i=1:length(ty) + npath = copy(path) + 
push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(todo, (npath, eltype(ty))) end - continue end - if isa(ty, LLVM.VectorType) - if any_jltypes(ty) - for i=1:size(ty) - npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) - push!(todo, (npath, eltype(ty))) - end + continue + end + if isa(ty, LLVM.VectorType) + if any_jltypes(ty) + for i=1:size(ty) + npath = copy(path) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(todo, (npath, eltype(ty))) end - continue end - if isa(ty, LLVM.StructType) - for (i, t) in enumerate(LLVM.elements(ty)) - if any_jltypes(t) - npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) - push!(todo, (npath, t)) - end + continue + end + if isa(ty, LLVM.StructType) + for (i, t) in enumerate(LLVM.elements(ty)) + if any_jltypes(t) + npath = copy(path) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(todo, (npath, t)) end - continue end + continue end - @assert count == tracked.count - end - if T_ret != T_void - ret!(builder, load!(builder, T_ret, sret)) - else - ret!(builder) end + @assert count == tracked.count + end + if T_ret != T_void + ret!(builder, load!(builder, T_ret, sret)) + else + ret!(builder) end # make sure that arguments are rooted if necessary @@ -4717,6 +4718,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function prargs = classify_arguments(functy, entry_ft, sret, returnRoots, swiftself, parmsRemoved) args = copy(prargs) filter!(args) do arg + Base.@_inline_meta arg.cc != GPUCompiler.GHOST && arg.cc != RemovedParam end @@ -6102,7 +6104,7 @@ struct CompileResult{AT, PT} TapeType::Type end -@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} = +@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) 
where {PT, FA, RT, TT, Width, ReturnPrimal} = enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = @@ -6232,6 +6234,7 @@ end rt::Type{RT}, fn::FA, ::Type{TapeType}, args::Vararg{Any, N}) where {RawCall, PT, FA, T, RT, TapeType, N, CC, width, returnPrimal} JuliaContext() do ctx + Base.@_inline_meta F = eltype(FA) is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk @@ -6555,67 +6558,67 @@ end mod = LLVM.parent(llvm_f) i64 = LLVM.IntType(64) - LLVM.IRBuilder() do builder - entry = BasicBlock(llvm_f, "entry") - position!(builder, entry) - callparams = collect(LLVM.Value, parameters(llvm_f)) - - if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) - lfn = callparams[1] - deleteat!(callparams, 1) - end - if returnRoots - tracked = CountTrackedPointers(jltype) - pushfirst!(callparams, alloca!(builder, LLVM.ArrayType(T_prjlvalue, tracked.count))) - pushfirst!(callparams, alloca!(builder, jltype)) - end + builder = LLVM.IRBuilder() + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) + callparams = collect(LLVM.Value, parameters(llvm_f)) - if needs_tape && !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType)) - tape = callparams[end] - if TapeType <: EnzymeTapeToLoad - llty = from_tape_type(eltype(TapeType)) - tape = bitcast!(builder, LLVM.PointerType(llty, LLVM.addrspace(value_type(tape)))) - tape = load!(builder, llty, tape) - API.SetMustCache!(tape) - callparams[end] = tape - else - llty = from_tape_type(TapeType) - @assert value_type(tape) == llty - end - end + if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) + lfn = callparams[1] + deleteat!(callparams, 1) + end - if 
!(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) - FT = LLVM.FunctionType(returnRoots ? T_void : T_ret, [value_type(x) for x in callparams]) - lfn = inttoptr!(builder, lfn, LLVM.PointerType(FT)) + if returnRoots + tracked = CountTrackedPointers(jltype) + pushfirst!(callparams, alloca!(builder, LLVM.ArrayType(T_prjlvalue, tracked.count))) + pushfirst!(callparams, alloca!(builder, jltype)) + end + + if needs_tape && !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType)) + tape = callparams[end] + if TapeType <: EnzymeTapeToLoad + llty = from_tape_type(eltype(TapeType)) + tape = bitcast!(builder, LLVM.PointerType(llty, LLVM.addrspace(value_type(tape)))) + tape = load!(builder, llty, tape) + API.SetMustCache!(tape) + callparams[end] = tape else - val_inner(::Type{Val{V}}) where V = V - submod, subname = val_inner(PT) - # TODO, consider optimization - # However, julia will optimize after this, so no need - submod = parse(LLVM.Module, String(submod)) - LLVM.link!(mod, submod) - lfn = functions(mod)[String(subname)] - FT = LLVM.function_type(lfn) + llty = from_tape_type(TapeType) + @assert value_type(tape) == llty end + end - r = call!(builder, FT, lfn, callparams) - - if returnRoots - attr = if LLVM.version().major >= 12 - TypeAttribute("sret", jltype) - else - EnumAttribute("sret") - end - LLVM.API.LLVMAddCallSiteAttribute(r, LLVM.API.LLVMAttributeIndex(1), attr) - r = load!(builder, eltype(value_type(callparams[1])), callparams[1]) - end + if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) + FT = LLVM.FunctionType(returnRoots ? 
T_void : T_ret, [value_type(x) for x in callparams]) + lfn = inttoptr!(builder, lfn, LLVM.PointerType(FT)) + else + val_inner(::Type{Val{V}}) where V = V + submod, subname = val_inner(PT) + # TODO, consider optimization + # However, julia will optimize after this, so no need + submod = parse(LLVM.Module, String(submod)) + LLVM.link!(mod, submod) + lfn = functions(mod)[String(subname)] + FT = LLVM.function_type(lfn) + end - if T_ret != T_void - ret!(builder, r) + r = call!(builder, FT, lfn, callparams) + + if returnRoots + attr = if LLVM.version().major >= 12 + TypeAttribute("sret", jltype) else - ret!(builder) + EnumAttribute("sret") end + LLVM.API.LLVMAddCallSiteAttribute(r, LLVM.API.LLVMAttributeIndex(1), attr) + r = load!(builder, eltype(value_type(callparams[1])), callparams[1]) + end + + if T_ret != T_void + ret!(builder, r) + else + ret!(builder) end reinsert_gcmarker!(llvm_f) @@ -6736,94 +6739,119 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline function thunkbase(mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} - JuliaContext() do ctx - - target = Compiler.EnzymeTarget() - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) - tmp_job = if World isa Nothing +@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} + target = 
Compiler.EnzymeTarget() + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) + tmp_job = if World isa Nothing Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) end - interp = GPUCompiler.get_interpreter(tmp_job) + interp = GPUCompiler.get_interpreter(tmp_job) - # TODO check compile return here, early - # rrt = Core.Compiler.return_type(f, primal.tt) # nothing - rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) - rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + # TODO check compile return here, early + # rrt = Core.Compiler.return_type(f, primal.tt) # nothing + rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype - run_enzyme = true + run_enzyme = true - A2 = if rrt == Union{} - run_enzyme = false - Const + A2 = if rrt == Union{} + run_enzyme = false + Const else - A - end - - if run_enzyme && !(A2 <: Const) && guaranteed_const_nongen(rrt, World) - estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" - return error(estr) - end + A + end + + if run_enzyme && !(A2 <: Const) && guaranteed_const_nongen(rrt, World) + estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" + return error(estr) + end - rt2 = if !run_enzyme - Const{rrt} - elseif A2 isa UnionAll - A2{rrt} - else - @assert A isa DataType - # Can we relax this condition? 
- # @assert eltype(A) == rrt - A2 - end - - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) - job = if World isa Nothing - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) - else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + rt2 = if !run_enzyme + Const{rrt} + elseif A2 isa UnionAll + A2{rrt} + else + @assert A isa DataType + # Can we relax this condition? + # @assert eltype(A) == rrt + A2 + end + + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) + job = if World isa Nothing + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) + else + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) end - # We need to use primal as the key, to lookup the right method - # but need to mixin the hash of the adjoint to avoid cache collisions - # This is counter-intuitive since we would expect the cache to be split - # by the primal, but we want the generated code to be invalidated by - # invalidations of the primal, which is managed by GPUCompiler. 
- - - compile_result = cached_compilation(job) - if !run_enzyme - ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal} - if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient - return (ErrT(compile_result.adjoint), ErrT(compile_result.adjoint)) - else - return ErrT(compile_result.adjoint) - end - elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient - TapeType = compile_result.TapeType - AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} - AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} - return (AugT(compile_result.primal), AdjT(compile_result.adjoint)) - elseif Mode == API.DEM_ReverseModeCombined - CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} - return CAdjT(compile_result.adjoint) - elseif Mode == API.DEM_ForwardMode - FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} - return FMT(compile_result.adjoint) + # We need to use primal as the key, to lookup the right method + # but need to mixin the hash of the adjoint to avoid cache collisions + # This is counter-intuitive since we would expect the cache to be split + # by the primal, but we want the generated code to be invalidated by + # invalidations of the primal, which is managed by GPUCompiler. 
+ + + compile_result = cached_compilation(job) + if !run_enzyme + ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal} + if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient + return (ErrT(compile_result.adjoint), ErrT(compile_result.adjoint)) else - @assert false - end + return ErrT(compile_result.adjoint) + end + elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient + TapeType = compile_result.TapeType + AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} + AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} + return (AugT(compile_result.primal), AdjT(compile_result.adjoint)) + elseif Mode == API.DEM_ReverseModeCombined + CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} + return CAdjT(compile_result.adjoint) + elseif Mode == API.DEM_ForwardMode + FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} + return FMT(compile_result.adjoint) + else + @assert false end end @inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI} - return thunkbase(mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + ts_ctx = JuliaContext() + ctx = @static if VERSION >= v"1.9.0-DEV.115" + context(ts_ctx) + else + ts_ctx + end + activate(ctx) + try + return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + finally + 
deactivate(ctx) + @static if VERSION >= v"1.9.0-DEV.115" + dispose(ts_ctx) + end + end end @inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} mi = fspec(eltype(FA), TT, World) - res = thunkbase(mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + ts_ctx = JuliaContext() + ctx = @static if VERSION >= v"1.9.0-DEV.115" + context(ts_ctx) + else + ts_ctx + end + activate(ctx) + res = try + thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + finally + deactivate(ctx) + @static if VERSION >= v"1.9.0-DEV.115" + dispose(ts_ctx) + end + end return quote Base.@_inline_meta return $(res) @@ -6835,7 +6863,7 @@ import GPUCompiler: deferred_codegen_jobs @generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType} JuliaContext() do ctx - + Base.@_inline_meta mi = fspec(eltype(FA), TT, World) target = EnzymeTarget() diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index a0283f899e..7506d2f565 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -340,7 +340,7 @@ function custom_rule_method_error(world, fn, args...) 
throw(MethodError(fn, (args...,), world)) end -function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -640,7 +640,7 @@ end return aug_fwd_mi(orig, gutils)[1] !== nothing end -function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef +@register_rev function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef ctx = LLVM.context(orig) @@ -1065,7 +1065,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, end -function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return true end @@ -1076,8 +1076,7 @@ function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end - -function enzyme_custom_rev(B, orig, gutils, tape) +@register_rev function enzyme_custom_rev(B, orig, gutils, tape) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return end @@ -1085,7 +1084,7 @@ function enzyme_custom_rev(B, orig, gutils, tape) return nothing end -function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) +@register_diffuse function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) # use default if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return (false, true) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index e99d118772..c624d85202 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1334,7 +1334,7 @@ function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) 
return false end -function generic_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function generic_fwd(B, orig, gutils, normalR, shadowR) conv = LLVM.callconv(orig) # https://github.com/JuliaLang/julia/blob/5162023b9b67265ddb0bbbc0f4bd6b225c429aa0/src/codegen_shared.h#L20 @assert conv == 37 @@ -1390,7 +1390,7 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) return false end -function generic_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function generic_augfwd(B, orig, gutils, normalR, shadowR, tapeR) conv = LLVM.callconv(orig) # https://github.com/JuliaLang/julia/blob/5162023b9b67265ddb0bbbc0f4bd6b225c429aa0/src/codegen_shared.h#L20 @@ -1414,7 +1414,7 @@ function common_generic_rev(offset, B, orig, gutils, tape)::Cvoid return nothing end -function generic_rev(B, orig, gutils, tape)::Cvoid +@register_rev function generic_rev(B, orig, gutils, tape)::Cvoid conv = LLVM.callconv(orig) # https://github.com/JuliaLang/julia/blob/5162023b9b67265ddb0bbbc0f4bd6b225c429aa0/src/codegen_shared.h#L20 @@ -1532,7 +1532,7 @@ function common_apply_latest_rev(offset, B, orig, gutils, tape)::Cvoid return nothing end -function apply_latest_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function apply_latest_fwd(B, orig, gutils, normalR, shadowR) conv = LLVM.callconv(orig) # https://github.com/JuliaLang/julia/blob/5162023b9b67265ddb0bbbc0f4bd6b225c429aa0/src/codegen_shared.h#L20 @assert conv == 37 @@ -1540,7 +1540,7 @@ function apply_latest_fwd(B, orig, gutils, normalR, shadowR) common_apply_latest_fwd(1, B, orig, gutils, normalR, shadowR) end -function apply_latest_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function apply_latest_augfwd(B, orig, gutils, normalR, shadowR, tapeR) conv = LLVM.callconv(orig) # https://github.com/JuliaLang/julia/blob/5162023b9b67265ddb0bbbc0f4bd6b225c429aa0/src/codegen_shared.h#L20 @assert conv == 37 @@ -1548,7 +1548,7 @@ function apply_latest_augfwd(B, orig, gutils, 
normalR, shadowR, tapeR) common_apply_latest_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end -function apply_latest_rev(B, orig, gutils, tape) +@register_rev function apply_latest_rev(B, orig, gutils, tape) conv = LLVM.callconv(orig) # https://github.com/JuliaLang/julia/blob/5162023b9b67265ddb0bbbc0f4bd6b225c429aa0/src/codegen_shared.h#L20 @assert conv == 37 @@ -1728,15 +1728,15 @@ function common_apply_iterate_rev(offset, B, orig, gutils, tape) return nothing end -function apply_iterate_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function apply_iterate_fwd(B, orig, gutils, normalR, shadowR) common_apply_iterate_fwd(1, B, orig, gutils, normalR, shadowR) end -function apply_iterate_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function apply_iterate_augfwd(B, orig, gutils, normalR, shadowR, tapeR) common_apply_iterate_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end -function apply_iterate_rev(B, orig, gutils, tape) +@register_rev function apply_iterate_rev(B, orig, gutils, tape) common_apply_iterate_rev(1, B, orig, gutils, tape) return nothing end @@ -1851,15 +1851,15 @@ function common_invoke_rev(offset, B, orig, gutils, tape) return nothing end -function invoke_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function invoke_fwd(B, orig, gutils, normalR, shadowR) common_invoke_fwd(1, B, orig, gutils, normalR, shadowR) end -function invoke_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function invoke_augfwd(B, orig, gutils, normalR, shadowR, tapeR) common_invoke_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end -function invoke_rev(B, orig, gutils, tape) +@register_rev function invoke_rev(B, orig, gutils, tape) common_invoke_rev(1, B, orig, gutils, tape) return nothing end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index beecfa5a2f..bafd3fb119 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1,9 +1,74 @@ +macro register_aug(expr) + decl = 
string(expr.args[1]) + name = decl[1:prevind(decl, findfirst('(', decl))] + cname = name*"_cfunc" + name = Symbol(name) + cname = Symbol(cname) + + expr2 = :(@inline $expr) + res = quote + function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::UInt8 + return UInt8($name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR, tapeR)::Bool) + end + end + return Expr(:block, esc(expr2), esc(res)) +end + +macro register_rev(expr) + decl = string(expr.args[1]) + name = decl[1:prevind(decl, findfirst('(', decl))] + cname = name*"_cfunc" + + name = Symbol(name) + cname = Symbol(cname) + expr2 = :(@inline $expr) + res = quote + function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid + $name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), tape == C_NULL ? 
nothing : LLVM.Value(tape)) + return + end + end + return Expr(:block, esc(expr2), esc(res)) +end + +macro register_fwd(expr) + decl = string(expr.args[1]) + name = decl[1:prevind(decl, findfirst('(', decl))] + cname = name*"_cfunc" + name = Symbol(name) + cname = Symbol(cname) + expr2 = :(@inline $expr) + res = quote + function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::UInt8 + return UInt8($name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR)::Bool) + end + end + return Expr(:block, esc(expr2), esc(res)) +end + +macro register_diffuse(expr) + decl = string(expr.args[1]) + name = decl[1:prevind(decl, findfirst('(', decl))] + cname = name*"_cfunc" + name = Symbol(name) + cname = Symbol(cname) + expr2 = :(@inline $expr) + res = quote + function $cname(OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, val::LLVM.API.LLVMValueRef, shadow::UInt8, mode::API.CDerivativeMode, useDefault::Ptr{UInt8})::UInt8 + res = $name(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} + unsafe_store!(useDefault, UInt8(res[2])) + return UInt8(res[1]) + end + end + return Expr(:block, esc(expr2), esc(res)) +end + include("customrules.jl") include("jitrules.jl") include("typeunstablerules.jl") include("parallelrules.jl") -function jlcall_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jlcall_fwd(B, orig, gutils, normalR, shadowR) F = operands(orig)[1] if isa(F, LLVM.Function) name = LLVM.name(F) @@ -44,7 +109,7 @@ function jlcall_fwd(B, orig, gutils, normalR, shadowR) return false end -function jlcall_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jlcall_augfwd(B, orig, gutils, normalR, shadowR, tapeR) F = operands(orig)[1] if isa(F, LLVM.Function) name = LLVM.name(F) @@ -85,7 +150,7 @@ function jlcall_augfwd(B, orig, 
gutils, normalR, shadowR, tapeR) return false end -function jlcall_rev(B, orig, gutils, tape) +@register_rev function jlcall_rev(B, orig, gutils, tape) F = operands(orig)[1] if isa(F, LLVM.Function) name = LLVM.name(F) @@ -135,7 +200,7 @@ function jlcall_rev(B, orig, gutils, tape) return nothing end -function jlcall2_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jlcall2_fwd(B, orig, gutils, normalR, shadowR) F = operands(orig)[1] if isa(F, LLVM.Function) name = LLVM.name(F) @@ -152,7 +217,7 @@ function jlcall2_fwd(B, orig, gutils, normalR, shadowR) return false end -function jlcall2_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jlcall2_augfwd(B, orig, gutils, normalR, shadowR, tapeR) F = operands(orig)[1] if isa(F, LLVM.Function) name = LLVM.name(F) @@ -169,7 +234,7 @@ function jlcall2_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function jlcall2_rev(B, orig, gutils, tape) +@register_rev function jlcall2_rev(B, orig, gutils, tape) F = operands(orig)[1] if isa(F, LLVM.Function) name = LLVM.name(F) @@ -188,15 +253,15 @@ function jlcall2_rev(B, orig, gutils, tape) end -function noop_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function noop_fwd(B, orig, gutils, normalR, shadowR) return true end -function noop_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function noop_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return true end -function duplicate_rev(B, orig, gutils, tape) +@register_rev function duplicate_rev(B, orig, gutils, tape) newg = new_from_original(gutils, orig) real_ops = collect(operands(orig))[1:end-1] @@ -208,7 +273,7 @@ function duplicate_rev(B, orig, gutils, tape) return nothing end -function arraycopy_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function arraycopy_fwd(B, orig, gutils, normalR, shadowR) ctx = LLVM.context(orig) if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL @@ -404,7 +469,7 @@ function arraycopy_common(fwd, B, 
orig, origArg, gutils, shadowdst) return nothing end -function arraycopy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function arraycopy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end @@ -421,7 +486,7 @@ function arraycopy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function arraycopy_rev(B, orig, gutils, tape) +@register_rev function arraycopy_rev(B, orig, gutils, tape) origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) arraycopy_common(#=fwd=#false, B, orig, origops[1], gutils, nothing) @@ -430,7 +495,7 @@ function arraycopy_rev(B, orig, gutils, tape) return nothing end -function arrayreshape_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function arrayreshape_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -465,15 +530,15 @@ function arrayreshape_fwd(B, orig, gutils, normalR, shadowR) return false end -function arrayreshape_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function arrayreshape_augfwd(B, orig, gutils, normalR, shadowR, tapeR) arrayreshape_fwd(B, orig, gutils, normalR, shadowR) end -function arrayreshape_rev(B, orig, gutils, tape) +@register_rev function arrayreshape_rev(B, orig, gutils, tape) return nothing end -function gcloaded_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function gcloaded_fwd(B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) @@ -513,15 +578,15 @@ function gcloaded_fwd(B, orig, gutils, normalR, shadowR) return false end -function gcloaded_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function gcloaded_augfwd(B, orig, gutils, normalR, shadowR, tapeR) gcloaded_fwd(B, 
orig, gutils, normalR, shadowR) end -function gcloaded_rev(B, orig, gutils, tape) +@register_rev function gcloaded_rev(B, orig, gutils, tape) return nothing end -function boxfloat_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function boxfloat_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) if is_constant_value(gutils, orig) @@ -548,7 +613,7 @@ function boxfloat_fwd(B, orig, gutils, normalR, shadowR) return false end -function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) origops = collect(operands(orig)) width = get_width(gutils) if is_constant_value(gutils, orig) @@ -576,7 +641,7 @@ function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function boxfloat_rev(B, orig, gutils, tape) +@register_rev function boxfloat_rev(B, orig, gutils, tape) origops = collect(operands(orig)) width = get_width(gutils) if !is_constant_value(gutils, orig) @@ -606,7 +671,7 @@ function boxfloat_rev(B, orig, gutils, tape) return nothing end -function eqtableget_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function eqtableget_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) return true end @@ -630,7 +695,7 @@ function error_if_active(::Type{T}) where T nothing end -function eqtableget_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function eqtableget_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) return true end @@ -701,11 +766,11 @@ function eqtableget_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function eqtableget_rev(B, orig, gutils, tape) +@register_rev function eqtableget_rev(B, orig, gutils, tape) return nothing end -function eqtableput_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function eqtableput_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && 
is_constant_inst(gutils, orig) return true end @@ -719,7 +784,7 @@ function eqtableput_fwd(B, orig, gutils, normalR, shadowR) return false end -function eqtableput_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function eqtableput_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -781,12 +846,12 @@ function eqtableput_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function eqtableput_rev(B, orig, gutils, tape) +@register_rev function eqtableput_rev(B, orig, gutils, tape) return nothing end -function idtablerehash_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function idtablerehash_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -800,7 +865,7 @@ function idtablerehash_fwd(B, orig, gutils, normalR, shadowR) return false end -function idtablerehash_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function idtablerehash_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -814,12 +879,12 @@ function idtablerehash_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function idtablerehash_rev(B, orig, gutils, tape) +@register_rev function idtablerehash_rev(B, orig, gutils, tape) emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_idtable_rehash") return nothing end -function jl_array_grow_end_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jl_array_grow_end_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) if is_constant_value(gutils, origops[1]) return true @@ -847,7 +912,7 @@ function jl_array_grow_end_fwd(B, orig, gutils, normalR, shadowR) end -function jl_array_grow_end_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jl_array_grow_end_augfwd(B, orig, gutils, normalR, 
shadowR, tapeR) origops = collect(operands(orig)) if is_constant_value(gutils, origops[1]) return true @@ -898,7 +963,7 @@ function jl_array_grow_end_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function jl_array_grow_end_rev(B, orig, gutils, tape) +@register_rev function jl_array_grow_end_rev(B, orig, gutils, tape) origops = collect(operands(orig)) if !is_constant_value(gutils, origops[1]) @@ -935,15 +1000,15 @@ function jl_array_grow_end_rev(B, orig, gutils, tape) return nothing end -function jl_array_del_end_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jl_array_del_end_fwd(B, orig, gutils, normalR, shadowR) jl_array_grow_end_fwd(B, orig, gutils, normalR, shadowR) end -function jl_array_del_end_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jl_array_del_end_augfwd(B, orig, gutils, normalR, shadowR, tapeR) jl_array_del_end_fwd(B, orig, gutils, normalR, shadowR) end -function jl_array_del_end_rev(B, orig, gutils, tape) +@register_rev function jl_array_del_end_rev(B, orig, gutils, tape) origops = collect(operands(orig)) if !is_constant_value(gutils, origops[1]) width = get_width(gutils) @@ -998,7 +1063,7 @@ function jl_array_del_end_rev(B, orig, gutils, tape) return nothing end -function jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) if is_constant_inst(gutils, orig) return true end @@ -1036,14 +1101,14 @@ function jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) return false end -function jl_array_ptr_copy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jl_array_ptr_copy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) end -function jl_array_ptr_copy_rev(B, orig, gutils, tape) +@register_rev function jl_array_ptr_copy_rev(B, orig, gutils, tape) return nothing end -function jl_array_sizehint_fwd(B, orig, gutils, normalR, 
shadowR) +@register_fwd function jl_array_sizehint_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) if is_constant_value(gutils, origops[1]) return true @@ -1070,15 +1135,15 @@ function jl_array_sizehint_fwd(B, orig, gutils, normalR, shadowR) return false end -function jl_array_sizehint_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jl_array_sizehint_augfwd(B, orig, gutils, normalR, shadowR, tapeR) jl_array_sizehint_fwd(B, orig, gutils, normalR, shadowR) end -function jl_array_sizehint_rev(B, orig, gutils, tape) +@register_rev function jl_array_sizehint_rev(B, orig, gutils, tape) return nothing end -function jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) newo = new_from_original(gutils, orig) origops = collect(operands(orig)) err = emit_error(B, orig, "Enzyme: unhandled forward for "*string(origops[end])) @@ -1100,14 +1165,14 @@ function jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) end return false end -function jl_unhandled_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jl_unhandled_augfwd(B, orig, gutils, normalR, shadowR, tapeR) jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) end -function jl_unhandled_rev(B, orig, gutils, tape) +@register_rev function jl_unhandled_rev(B, orig, gutils, tape) return nothing end -function get_binding_or_error_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function get_binding_or_error_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) return true end @@ -1133,7 +1198,7 @@ function get_binding_or_error_fwd(B, orig, gutils, normalR, shadowR) return false end -function get_binding_or_error_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function get_binding_or_error_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) return true end @@ -1158,12 +1223,12 @@ function 
get_binding_or_error_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function get_binding_or_error_rev(B, orig, gutils, tape) +@register_rev function get_binding_or_error_rev(B, orig, gutils, tape) emit_error(B, orig, "Enzyme: unhandled reverse for jl_get_binding_or_error") return nothing end -function finalizer_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function finalizer_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -1177,7 +1242,7 @@ function finalizer_fwd(B, orig, gutils, normalR, shadowR) return false end -function finalizer_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function finalizer_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -1198,7 +1263,7 @@ function finalizer_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function finalizer_rev(B, orig, gutils, tape) +@register_rev function finalizer_rev(B, orig, gutils, tape) # emit_error(B, orig, "Enzyme: unhandled reverse for jl_gc_add_finalizer_th") return nothing end @@ -1216,33 +1281,26 @@ function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=no end macro augfunc(f) - :(@cfunction((B, OrigCI, gutils, normalR, shadowR, tapeR) -> begin - UInt8($f(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR, tapeR)::Bool) - end, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) + cname = Symbol(string(f)*"_cfunc") + :(@cfunction($cname, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) )) end macro revfunc(f) - :(@cfunction((B, OrigCI, gutils, tape) -> begin - $f(LLVM.IRBuilder(B), 
LLVM.CallInst(OrigCI), GradientUtils(gutils), tape == C_NULL ? nothing : LLVM.Value(tape)) - end, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef) + cname = Symbol(string(f)*"_cfunc") + :(@cfunction($cname, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef) )) end macro fwdfunc(f) - :(@cfunction((B, OrigCI, gutils, normalR, shadowR) -> begin - UInt8($f(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR)::Bool) - end, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) + cname = Symbol(string(f)*"_cfunc") + :(@cfunction($cname, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) )) end - macro diffusefunc(f) - :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin - res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} - unsafe_store!(useDefault, UInt8(res[2])) - UInt8(res[1]) - end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) + cname = Symbol(string(f)*"_cfunc") + :(@cfunction(Compiler.$cname, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) )) end @@ -1290,12 +1348,6 @@ end @revfunc(threadsfor_rev), @fwdfunc(threadsfor_fwd), ) - register_handler!( - ("jl_pmap",), - @augfunc(pmap_augfwd), - @revfunc(pmap_rev), - @fwdfunc(pmap_fwd), - ) register_handler!( ("jl_new_task", "ijl_new_task"), @augfunc(newtask_augfwd), diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 6aa661feef..e4fbab6faf 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -381,7 +381,7 @@ end return refed, 
LLVM.name(subfunc), dfuncT, vals, thunkTy, TapeType, copies end -function threadsfor_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function threadsfor_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -419,7 +419,7 @@ end return false end -function threadsfor_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function threadsfor_augfwd(B, orig, gutils, normalR, shadowR, tapeR) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) @@ -474,7 +474,7 @@ end return false end -function threadsfor_rev(B, orig, gutils, tape) +@register_rev function threadsfor_rev(B, orig, gutils, tape) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) world = enzyme_extract_world(LLVM.parent(position(B))) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) @@ -512,7 +512,7 @@ end return nothing end -function newtask_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function newtask_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -550,7 +550,7 @@ function newtask_fwd(B, orig, gutils, normalR, shadowR) return false end -function newtask_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function newtask_augfwd(B, orig, gutils, normalR, shadowR, tapeR) # fn, dfn = augmentAndGradient(fn) # t = jl_new_task(fn) # # shadow t @@ -608,11 +608,11 @@ function newtask_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function newtask_rev(B, orig, gutils, tape) +@register_rev function newtask_rev(B, orig, gutils, tape) return nothing end -function set_task_tid_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function set_task_tid_fwd(B, orig, gutils, normalR, shadowR) ops = collect(operands(orig))[1:end-1] if is_constant_value(gutils, ops[1]) return true @@ -641,15 +641,15 @@ function set_task_tid_fwd(B, orig, 
gutils, normalR, shadowR) return false end -function set_task_tid_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function set_task_tid_augfwd(B, orig, gutils, normalR, shadowR, tapeR) set_task_tid_fwd(B, orig, gutils, normalR, shadowR) end -function set_task_tid_rev(B, orig, gutils, tape) +@register_rev function set_task_tid_rev(B, orig, gutils, tape) return nothing end -function enq_work_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function enq_work_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -661,7 +661,7 @@ function enq_work_fwd(B, orig, gutils, normalR, shadowR) return false end -function enq_work_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function enq_work_augfwd(B, orig, gutils, normalR, shadowR, tapeR) enq_work_fwd(B, orig, gutils, normalR, shadowR) end @@ -684,7 +684,7 @@ function find_match(mod, name) return nothing end -function enq_work_rev(B, orig, gutils, tape) +@register_rev function enq_work_rev(B, orig, gutils, tape) # jl_wait(shadow(t)) origops = LLVM.operands(orig) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -701,7 +701,7 @@ function enq_work_rev(B, orig, gutils, tape) return nothing end -function wait_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function wait_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -712,7 +712,7 @@ function wait_fwd(B, orig, gutils, normalR, shadowR) return false end -function wait_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function wait_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -723,7 +723,7 @@ function wait_augfwd(B, orig, gutils, normalR, shadowR, tapeR) return false end -function wait_rev(B, orig, gutils, tape) +@register_rev function wait_rev(B, orig, gutils, tape) # 
jl_enq_work(shadow(t)) origops = LLVM.operands(orig) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index a09968a325..e092f81fb0 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -471,33 +471,33 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) end -function f_tuple_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function f_tuple_fwd(B, orig, gutils, normalR, shadowR) common_f_tuple_fwd(1, B, orig, gutils, normalR, shadowR) end -function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool +@register_aug function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool common_f_tuple_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end -function f_tuple_rev(B, orig, gutils, tape) +@register_rev function f_tuple_rev(B, orig, gutils, tape) common_f_tuple_rev(1, B, orig, gutils, tape) return nothing end -function new_structv_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function new_structv_fwd(B, orig, gutils, normalR, shadowR) common_newstructv_fwd(1, B, orig, gutils, normalR, shadowR) end -function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool +@register_aug function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool common_newstructv_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end -function new_structv_rev(B, orig, gutils, tape) +@register_rev function new_structv_rev(B, orig, gutils, tape) common_apply_latest_rev(1, B, orig, gutils, tape) return nothing end -function new_structt_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function new_structt_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end @@ -526,11 +526,12 @@ function new_structt_fwd(B, orig, gutils, normalR, shadowR) unsafe_store!(shadowR, shadowres.ref) return false end -function new_structt_augfwd(B, orig, gutils, normalR, 
shadowR, tapeR)::Bool + +@register_aug function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool new_structt_fwd(B, orig, gutils, normalR, shadowR) end -function new_structt_rev(B, orig, gutils, tape) +@register_rev function new_structt_rev(B, orig, gutils, tape) if is_constant_value(gutils, orig) return true end @@ -978,7 +979,7 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) return nothing end -function jl_nthfield_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jl_nthfield_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end @@ -1020,7 +1021,7 @@ function jl_nthfield_fwd(B, orig, gutils, normalR, shadowR) end return false end -function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end @@ -1097,7 +1098,7 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) unsafe_store!(tapeR, cal.ref) return false end -function jl_nthfield_rev(B, orig, gutils, tape) +@register_rev function jl_nthfield_rev(B, orig, gutils, tape) if is_constant_value(gutils, orig) return end @@ -1159,13 +1160,13 @@ function jl_nthfield_rev(B, orig, gutils, tape) return nothing end -function jl_getfield_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function jl_getfield_fwd(B, orig, gutils, normalR, shadowR) common_jl_getfield_fwd(1, B, orig, gutils, normalR, shadowR) end -function jl_getfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function jl_getfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) common_jl_getfield_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end -function jl_getfield_rev(B, orig, gutils, tape) +@register_rev function jl_getfield_rev(B, orig, gutils, tape) common_jl_getfield_rev(1, B, orig, gutils, tape) end @@ -1314,15 
+1315,15 @@ function common_setfield_rev(offset, B, orig, gutils, tape) end -function setfield_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function setfield_fwd(B, orig, gutils, normalR, shadowR) common_setfield_fwd(1, B, orig, gutils, normalR, shadowR) end -function setfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function setfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) common_setfield_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) end -function setfield_rev(B, orig, gutils, tape) +@register_rev function setfield_rev(B, orig, gutils, tape) common_setfield_rev(1, B, orig, gutils, tape) end @@ -1438,17 +1439,17 @@ function common_finalizer_rev(offset, B, orig, gutils, tape) return nothing end -function f_svec_ref_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function f_svec_ref_fwd(B, orig, gutils, normalR, shadowR) common_f_svec_ref_fwd(1, B, orig, gutils, normalR, shadowR) return nothing end -function f_svec_ref_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function f_svec_ref_augfwd(B, orig, gutils, normalR, shadowR, tapeR) common_f_svec_ref_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR) return nothing end -function f_svec_ref_rev(B, orig, gutils, tape) +@register_rev function f_svec_ref_rev(B, orig, gutils, tape) common_f_svec_ref_rev(1, B, orig, gutils, tape) return nothing end From 76205816822b28171b578ae40d5bcea692ca225b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Sun, 21 Jul 2024 06:35:00 +0200 Subject: [PATCH 183/495] Fix typo in jl_array_ptr_copy_fwd (#1648) --- src/rules/llvmrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index bafd3fb119..e52c4e09c6 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1093,7 +1093,7 @@ end push!(vargs, extract_value!(B, a, idx-1)) end push!(vargs, args[end]) - cal = call_samefunc_with_inverted_bundles!(b, gutils, orig, vargs, valTys, 
#=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, #=lookup=#false) debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) end From 87338e42884b6f65c8bd8e97fdb0e00c0ac2b762 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 20 Jul 2024 21:35:57 -0700 Subject: [PATCH 184/495] Fix #1630 (#1631) --- src/rules/customrules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 7506d2f565..34e8f2d8df 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -231,7 +231,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, if mixed RefTy = arg.typ if width != 1 - RefTy = NTuple{N, RefTy} + RefTy = NTuple{Int(width), RefTy} end llrty = convert(LLVMType, RefTy) RefTy = Base.RefValue{RefTy} @@ -1035,7 +1035,7 @@ end for (ptr_val, argTyp, refal) in mixeds RefTy = argTyp if width != 1 - RefTy = NTuple{N, RefTy} + RefTy = NTuple{Int(width), RefTy} end curs = load!(B, convert(LLVMType, RefTy), refal) From 1e15769e6a3fe4bcedac24d521d091a398347d93 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 21 Jul 2024 02:49:18 -0400 Subject: [PATCH 185/495] Mark regex fns as nofree (#1654) * Mark regex fns as nofree * more pcre --- src/compiler.jl | 8 ++++++++ src/internal_rules.jl | 3 +++ src/rules/typeunstablerules.jl | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index f53996aed5..ad8769aca3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -155,6 +155,13 @@ end end const nofreefns = Set{String}(( + "pcre2_match_8", + "julia.gcroot_flush", + "pcre2_jit_stack_assign_8", + "pcre2_match_context_create_8", + "pcre2_jit_stack_create_8", + "ijl_gc_enable_finalizers_internal", "jl_gc_enable_finalizers_internal", + "pcre2_match_data_create_from_pattern_8", "ijl_gc_run_pending_finalizers", "jl_gc_run_pending_finalizers", "ijl_typeassert", 
"jl_typeassert", "ijl_f_isdefined", "jl_f_isdefined", @@ -239,6 +246,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "pcre2_match_data_create_from_pattern_8", "ijl_typeassert", "jl_typeassert", "ijl_f_isdefined", "jl_f_isdefined", "ijl_field_index", "jl_field_index", diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 65933b4237..d874dd5380 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -109,6 +109,9 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.hasproperty), args...) return nothing end +function EnzymeRules.inactive(::typeof(Base.startswith), ::AbstractString, args...) + return nothing +end if VERSION >= v"1.9" Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) = nothing diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index e092f81fb0..2765edecb9 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1423,7 +1423,7 @@ function common_finalizer_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function common_finalizer_augfwd(offset, B, orig, gutils, normalR, shadowR) +function common_finalizer_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end From e585a7d8dde8dedd4152d104eadcb2a1eec19d1c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 21 Jul 2024 12:05:13 -0400 Subject: [PATCH 186/495] Bump jll (#1657) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e0f9f3f645..d2faef44a6 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.7" -Enzyme_jll = "0.0.135" +Enzyme_jll = "0.0.136" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From 4f3365d088e64133ccc082a242bc465eb519c9c1 Mon Sep 17 
00:00:00 2001 From: William Moses Date: Sun, 21 Jul 2024 13:33:37 -0400 Subject: [PATCH 187/495] Bigfloat constructor rules (#1658) --- src/compiler.jl | 4 +++ src/internal_rules.jl | 62 ++++++++++++++++++++++++++++++++++++++++ src/rules/customrules.jl | 2 +- 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index ad8769aca3..e789048bde 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -486,6 +486,10 @@ end return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) end + if T <: BigFloat + return DupState + end + if T <: AbstractFloat return ActiveState end diff --git a/src/internal_rules.jl b/src/internal_rules.jl index d874dd5380..8c703effe6 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -815,3 +815,65 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end end end + +function EnzymeRules.forward( + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; + kwargs... + ) + if RT <: Const + return Ty.val(; kwargs...) + elseif RT <: DuplicatedNoNeed + return Ty.val(; kwargs...) + elseif RT <: Duplicated + return RT(Ty.val(; kwargs...), Ty.val(; kwargs...)) + elseif RT <: BatchDuplicatedNoNeed + ntuple(Val(width(RT))) do i + Base.@_inline_meta + Ty.val(; kwargs...) + end + else + @assert RT <: BatchDuplicated + tup = ntuple(Val(width(RT))) do i + Base.@_inline_meta + Ty.val(; kwargs...) + end + RT(Ty.val(; kwargs...), tup) + end +end + +function EnzymeRules.augmented_primal( + config, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, + kwargs... + ) + primal = if EnzymeRules.needs_primal(config) + Ty.val(; kwargs...) + else + nothing + end + shadow = if RT <: Const + shadow = nothing + else + if EnzymeRules.width(config) == 1 + Ty.val(; kwargs...) 
+ else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + Ty.val(; kwargs...) + end + end + end + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, + tape, + kwargs..., + ) + return () +end \ No newline at end of file diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 34e8f2d8df..64508660ab 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -825,7 +825,7 @@ end val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B)) API.EnzymeGradientUtilsSetDiffe(gutils, orig, LLVM.null(value_type(val)), B) else - llety = convert(LLVMType, eltype(RT)) + llety = convert(LLVMType, eltype(RT); allow_boxed=true) ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B) val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety))) for idx in 1:width From c0caf9acf287f131fdbff049b9fcd2c69f379204 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 21 Jul 2024 14:09:58 -0400 Subject: [PATCH 188/495] GPU report exception: fix linkage (#1659) --- src/compiler.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index e789048bde..a7409355ec 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5796,6 +5796,12 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # annotate annotate!(mod, mode) + if haskey(functions(mod), "gpu_report_exception") + exc = functions(mod)["gpu_report_exception"] + if !isempty(blocks(exc)) + linkage!(exc, LLVM.API.LLVMExternalLinkage) + end + end # Run early pipeline optimize!(mod, target_machine) @@ -5803,6 +5809,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if process_module GPUCompiler.optimize_module!(parent_job, mod) end + + if haskey(functions(mod), "gpu_report_exception") 
+ exc = functions(mod)["gpu_report_exception"] + if !isempty(blocks(exc)) + linkage!(exc, LLVM.API.LLVMInternalLinkage) + end + end seen = TypeTreeTable() T_jlvalue = LLVM.StructType(LLVMType[]) From 3aa6a5a00d85106fc7324241046c7bb1833f7943 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 22 Jul 2024 10:09:47 -0400 Subject: [PATCH 189/495] Add internal forward-mode rules for ranges (#1655) * Add internal forward-mode rules for ranges This is part 1 one solving https://github.com/EnzymeAD/Enzyme.jl/issues/274. It does the forward mode rules as those are simpler. A separate PR will do the WIP reverse mode rules as that seems to be a bit more complex. Add missing `@test` don't forget the rule * namespace * Update internal_rules.jl * Update internal_rules.jl * Update src/internal_rules.jl * Update internal_rules.jl * Update internal_rules.jl --------- Co-authored-by: William Moses --- src/internal_rules.jl | 43 +++++++++++++++++++++++++++++++++++++++++- test/internal_rules.jl | 23 ++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 8c703effe6..8adf38f258 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -816,6 +816,47 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end end +# Ranges +# Float64 ranges in Julia use bitwise `&` with higher precision +# to correct for numerical error, thus we put rules over the +# operations as this is not directly differentiable +function EnzymeRules.forward(func::Const{Colon}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, start::Annotation, step::Annotation, stop::Annotation) + ret = func.val(start.val, step.val, stop.val) + dstart = if start isa Const + zero(eltype(ret)) + elseif start isa Duplicated || start isa DuplicatedNoNeed + one(eltype(ret)) + elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed + ntuple(x->one(eltype(ret)), Val(width(RT))) + else + error("Annotation type 
$(typeof(start)) not supported for range start. Please open an issue") + end + + dstep = if step isa Const + zero(eltype(ret)) + elseif step isa Duplicated || step isa DuplicatedNoNeed + one(eltype(ret)) + elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed + ntuple(x->one(eltype(ret)), Val(width(RT))) + else + error("Annotation type $(typeof(start)) not supported for range step. Please open an issue") + end + + if RT <: Duplicated + Duplicated(ret, range(dstart, step=dstep, length=length(ret))) + elseif RT <: Const + ret + elseif RT <: DuplicatedNoNeed + range(dstart, step=dstep, length=length(ret)) + elseif RT <: BatchDuplicated + BatchDuplicated(ret, ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT)))) + elseif RT <: BatchDuplicatedNoNeed + ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT))) + else + error("This should not be possible. Please report.") + end +end + function EnzymeRules.forward( Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; @@ -876,4 +917,4 @@ function EnzymeRules.reverse( kwargs..., ) return () -end \ No newline at end of file +end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 7cc5c07321..25c7ab2838 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -618,4 +618,27 @@ end @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) end +@testset "Ranges" begin + function f1(x) + ts = Array(0.0:x:3.0) + sum(ts) + end + function f2(x) + ts = Array(0.0:.25:3.0) + sum(ts) + x + end + function f3(x) + ts = Array(x:.25:3.0) + sum(ts) + end + function f4(x) + ts = Array(0.0:.25:x) + sum(ts) + end + @test Enzyme.autodiff(Forward, f1, Duplicated(0.25, 1.0)) == (78,) + @test Enzyme.autodiff(Forward, f2, Duplicated(0.25, 1.0)) == (1.0,) + @test Enzyme.autodiff(Forward, f3, Duplicated(0.25, 1.0)) == (12,) + @test Enzyme.autodiff(Forward, f4, Duplicated(3.0, 
1.0)) == (0,) +end + end # InternalRules From 21f97518dd2b636831ceede117043303bc3f7a2f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 22 Jul 2024 10:10:13 -0400 Subject: [PATCH 190/495] Fix tape union check (#1660) * Fix tape union check * Fix applicable fn --- src/rules/customrules.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 64508660ab..6a5c101786 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -700,6 +700,7 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) llvmf = nothing + applicablefn = true if forward llvmf = nested_codegen!(mode, mod, ami, world) @@ -733,6 +734,7 @@ end llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) pushfirst!(args, LLVM.ConstantInt(world)) rev_RT = Union{} + applicablefn = false end else if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) @@ -744,6 +746,7 @@ end llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) pushfirst!(args, LLVM.ConstantInt(world)) rev_RT = Union{} + applicablefn = false end end end @@ -779,11 +782,11 @@ end funcTy = rev_TT.parameters[isKWCall ? 
4 : 2] if needsTape @assert tape != C_NULL - tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy) + (rev_RT == Union{}) + tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy) + (!applicablefn) trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself + (RT <: Active) innerTy = value_type(parameters(llvmf)[trueidx]) if innerTy != value_type(tape) - if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL || TapeT == Array + if isabstracttype(TapeT) || TapeT isa UnionAll || TapeT == Tuple || TapeT.layout == C_NULL || TapeT == Array msg = sprint() do io println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") println(io, "tape_idx=", tape_idx) @@ -797,6 +800,8 @@ end println(io, "returnRoots=", returnRoots) println(io, "swiftself=", swiftself) println(io, "RT=", RT) + println(io, "rev_RT=", rev_RT) + println(io, "applicablefn=", applicablefn) println(io, "tape=", tape) println(io, "llvmf=", string(LLVM.function_type(llvmf))) println(io, "TapeT=", TapeT) @@ -846,7 +851,7 @@ end if any_jltypes(llty) emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)) + (rev_RT == Union{}), al) + insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)) + (!applicablefn), al) end end From f207a05b2ea22d83a2541021deac7cea6ff00801 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 22 Jul 2024 16:58:09 -0400 Subject: [PATCH 191/495] Tape type of error (#1666) --- src/Enzyme.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 05ba2ae4c0..9ff56bdd81 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -722,8 +722,12 @@ end Val(codegen_world_age(eltype(FA), primal_tt)) end nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, 
#=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) - TapeType = EnzymeRules.tape_type(nondef[1]) - return TapeType + if nondef[1] isa Enzyme.Compiler.PrimalErrorThunk + return Nothing + else + TapeType = EnzymeRules.tape_type(nondef[1]) + return TapeType + end end const tape_cache = Dict{UInt, Type}() From 7744a8cbf0060c413116b84712a5b7c538644bdd Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 22 Jul 2024 16:59:14 -0400 Subject: [PATCH 192/495] Fix make_zero box infinite recursion (#1665) --- src/compiler.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index a7409355ec..7f41099462 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -489,7 +489,7 @@ end if T <: BigFloat return DupState end - + if T <: AbstractFloat return ActiveState end @@ -1425,8 +1425,9 @@ end return seen[prev] end prev2 = prev.contents - res = Core.Box(Base.Ref(EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)))) + res = Core.Box() seen[prev] = res + res.contents = Base.Ref(EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive))) return res end From 73abcf1dba50e2fa5cf446889c793ccad474579c Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 22 Jul 2024 19:56:39 -0400 Subject: [PATCH 193/495] Pushforward dvals in the ranges forward rule (#1663) * Pushforward dvals in the ranges forward rule I just realized before it merged that we're actually missing a part of the pushforward here. Even though the solution to the derivative is just the range itself, we forgot to multiply it by the dval to pushforward the derivative. It implicitly had it as one, calculating the derivative, instead of the full pushforward. The test should be updated to catch this. 
* Test dual propagation Changes the inputs to be the same but have another operation * fix and test batch rules and format --- src/internal_rules.jl | 36 ++++++++++++++++++++++-------------- test/internal_rules.jl | 37 +++++++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 8adf38f258..e3040de747 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -820,38 +820,46 @@ end # Float64 ranges in Julia use bitwise `&` with higher precision # to correct for numerical error, thus we put rules over the # operations as this is not directly differentiable -function EnzymeRules.forward(func::Const{Colon}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, start::Annotation, step::Annotation, stop::Annotation) +function EnzymeRules.forward(func::Const{Colon}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicated,BatchDuplicatedNoNeed}}, + start::Annotation, step::Annotation, stop::Annotation) ret = func.val(start.val, step.val, stop.val) - dstart = if start isa Const - zero(eltype(ret)) + dstart = if start isa Const + zero(eltype(ret)) elseif start isa Duplicated || start isa DuplicatedNoNeed - one(eltype(ret)) + start.dval elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed - ntuple(x->one(eltype(ret)), Val(width(RT))) + ntuple(i -> start.dval[i], Val(width(RT))) else error("Annotation type $(typeof(start)) not supported for range start. Please open an issue") end - dstep = if step isa Const - zero(eltype(ret)) + dstep = if step isa Const + zero(eltype(ret)) elseif step isa Duplicated || step isa DuplicatedNoNeed - one(eltype(ret)) + step.dval elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed - ntuple(x->one(eltype(ret)), Val(width(RT))) + ntuple(i -> step.dval[i], Val(width(RT))) else error("Annotation type $(typeof(start)) not supported for range step. 
Please open an issue") end - if RT <: Duplicated - Duplicated(ret, range(dstart, step=dstep, length=length(ret))) + if RT <: Duplicated + Duplicated(ret, range(dstart; step=dstep, length=length(ret))) elseif RT <: Const ret elseif RT <: DuplicatedNoNeed - range(dstart, step=dstep, length=length(ret)) + range(dstart; step=dstep, length=length(ret)) elseif RT <: BatchDuplicated - BatchDuplicated(ret, ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT)))) + BatchDuplicated(ret, + ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; + step=dstep isa Number ? dstep : dstep[i], + length=length(ret)), Val(width(RT)))) elseif RT <: BatchDuplicatedNoNeed - ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT))) + ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; + step=dstep isa Number ? dstep : dstep[i], + length=length(ret)), Val(width(RT))) else error("This should not be possible. Please report.") end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 25c7ab2838..9a5c0bdbb2 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -620,25 +620,38 @@ end @testset "Ranges" begin function f1(x) + x = 25.0x ts = Array(0.0:x:3.0) - sum(ts) + return sum(ts) end function f2(x) - ts = Array(0.0:.25:3.0) - sum(ts) + x + x = 25.0x + ts = Array(0.0:0.25:3.0) + return sum(ts) + x end function f3(x) - ts = Array(x:.25:3.0) - sum(ts) + x = 25.0x + ts = Array(x:0.25:3.0) + return sum(ts) end function f4(x) - ts = Array(0.0:.25:x) - sum(ts) - end - @test Enzyme.autodiff(Forward, f1, Duplicated(0.25, 1.0)) == (78,) - @test Enzyme.autodiff(Forward, f2, Duplicated(0.25, 1.0)) == (1.0,) - @test Enzyme.autodiff(Forward, f3, Duplicated(0.25, 1.0)) == (12,) - @test Enzyme.autodiff(Forward, f4, Duplicated(3.0, 1.0)) == (0,) + x = 25.0x + ts = Array(0.0:0.25:x) + return sum(ts) + end + @test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (25.0,) + @test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == 
(25.0,) + @test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (75.0,) + @test Enzyme.autodiff(Forward, f4, Duplicated(0.12, 1.0)) == (0,) + + @test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=25.0, var"2"=50.0),) + @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=25.0, var"2"=50.0),) + @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=75.0, var"2"=150.0),) + @test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) == + ((var"1"=0.0, var"2"=0.0),) end end # InternalRules From 426495faebdd478d2e389d4d4f65235ad7d5df39 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 22 Jul 2024 21:26:05 -0400 Subject: [PATCH 194/495] Inactive boxed cache (#1667) * Inactive boxed cache * noroot --- src/compiler.jl | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7f41099462..00971c170c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -314,6 +314,13 @@ const activefns = Set{String}(( "jl_", )) +const inactiveglobs = Set{String}(( + "ijl_boxed_uint8_cache", + "jl_boxed_uint8_cache", + "ijl_boxed_int8_cache", + "jl_boxed_int8_cache", +)) + @enum ActivityState begin AnyState = 0 ActiveState = 1 @@ -3270,6 +3277,16 @@ function annotate!(mod, mode) for f in fns API.EnzymeAttributeKnownFunctions(f.ref) end + +@static if VERSION >= v"1.8-" + for gname in inactiveglobs + globs = LLVM.globals(mod) + if haskey(globs, gname) + glob = globs[gname] + metadata(glob)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) + end + end +end for fname in inactivefns if haskey(fns, fname) @@ -4823,9 +4840,6 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function ops = collect(operands(ci))[1:end-1] position!(builder, ci) nops = LLVM.Value[] - if returnRoots - push!(nops, ops[1+sret]) - end if swiftself push!(nops, ops[1+sret+returnRoots]) end From 
5fc23099cdac3dbdfa9d413ee3f17b34759bb6f2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 23 Jul 2024 00:35:47 -0400 Subject: [PATCH 195/495] Update Project.toml (#1668) --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index d2faef44a6..d94f35ec11 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.24" +version = "0.12.25" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -32,7 +32,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.7" -Enzyme_jll = "0.0.136" +Enzyme_jll = "0.0.137" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From 4b464ac935a4b3090ec345fd959954f4811a970a Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Thu, 25 Jul 2024 00:09:40 +0100 Subject: [PATCH 196/495] ci: add DynamicExpressions integration tests (#1675) * ci: add DynamicExpressions integration tests * ci: set to 1.10 for tests * ci: fix DE integration test compat * ci: fix DE integration test env * ci: fix DE integration command --- .github/workflows/CI.yml | 26 ++++++++++++++++++++++ test/integration/DynamicExpressions.jl | 30 ++++++++++++++++++++++++++ test/integration/Project.toml | 5 +++++ 3 files changed, 61 insertions(+) create mode 100644 test/integration/DynamicExpressions.jl create mode 100644 test/integration/Project.toml diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index bc20420585..20415db568 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -225,6 +225,32 @@ jobs: if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: files: lcov.info + integration: + name: Integration Tests - ${{ matrix.test }} + runs-on: ${{ matrix.os }} + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + strategy: + fail-fast: 
false + matrix: + version: + - '1.10' + os: + - ubuntu-latest + test: + - DynamicExpressions + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@v1 + - name: "Run tests" + run: | + julia --color=yes --project=test/integration -e 'using Pkg; Pkg.develop([PackageSpec(; path) for path in (".", "lib/EnzymeCore")]); Pkg.instantiate()' + julia --color=yes --project=test/integration --threads=auto --check-bounds=yes test/integration/${{ matrix.test }}.jl + shell: bash docs: name: Documentation runs-on: ubuntu-latest diff --git a/test/integration/DynamicExpressions.jl b/test/integration/DynamicExpressions.jl new file mode 100644 index 0000000000..dc626cb77a --- /dev/null +++ b/test/integration/DynamicExpressions.jl @@ -0,0 +1,30 @@ +using Test, Enzyme, DynamicExpressions + +operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin)) + +tree = Node(; op=1, l=Node{Float64}(; feature=1), r=Node(; op=1, l=Node{Float64}(; feature=2))) +# == x1 + cos(x2) + +X = randn(3, 100); +dX = zero(X) + +function f(tree, X, operators, output) + output[] = sum(eval_tree_array(tree, X, operators)[1]) + return nothing +end + +output = [0.0] +doutput = [1.0] + +autodiff( + Reverse, + f, + Const(tree), + Duplicated(X, dX), + Const(operators), + Duplicated(output, doutput), +) + +true_dX = cat(ones(100), -sin.(X[2, :]), zeros(100); dims=2)' + +@test true_dX ≈ dX diff --git a/test/integration/Project.toml b/test/integration/Project.toml new file mode 100644 index 0000000000..dd4a40af73 --- /dev/null +++ b/test/integration/Project.toml @@ -0,0 +1,5 @@ +[deps] +DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" + +[compat] +DynamicExpressions = "=0.18.5" From 07ebcd3c68c8d8a8784fbaeb283a3eca6583047e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 11:09:11 -0400 Subject: [PATCH 197/495] Bump jll with extract 
tuple for reverse fix (#1678) --- Project.toml | 2 +- test/runtests.jl | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d94f35ec11..2760fa45a1 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.7" -Enzyme_jll = "0.0.137" +Enzyme_jll = "0.0.138" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" diff --git a/test/runtests.jl b/test/runtests.jl index 15b3c3ddf4..6534d9a2a5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2239,6 +2239,19 @@ end end end +function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T + a1 = 1 / @inbounds poly[1] + E1 = 2*a1 + E12 = E1*E1 + s1 = log(E12) + return nothing +end + +@testset "Extract Tuple for Reverse" begin + autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(solve_cubic_eq)}, Const, Duplicated{Vector{Complex{Float64}}}) +end + + @testset "GetField" begin mutable struct MyType x::Float64 From ad1f2d853a70c003fa0944ef8a4e87e801e4bdc9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 16:39:57 -0400 Subject: [PATCH 198/495] Fix bitcast arg (#1679) * bitcast arg * Update compiler.jl --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 00971c170c..f709128580 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6619,7 +6619,7 @@ end tape = callparams[end] if TapeType <: EnzymeTapeToLoad llty = from_tape_type(eltype(TapeType)) - tape = bitcast!(builder, LLVM.PointerType(llty, LLVM.addrspace(value_type(tape)))) + tape = bitcast!(builder, tape, LLVM.PointerType(llty, LLVM.addrspace(value_type(tape)))) tape = load!(builder, llty, tape) API.SetMustCache!(tape) callparams[end] = tape From 3d1d6e6aeabcf199245352089d8910e85bec6e50 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 27 Jul 2024 19:27:14 -0400 Subject: 
[PATCH 199/495] Fix array of tuple any (#1685) * Fix array of tuple any * Update typetree.jl * Update typetree.jl --- src/typetree.jl | 4 ++-- test/runtests.jl | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index 2ab6cd4a50..f8c70808be 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -161,7 +161,7 @@ end offset = 0 tt = copy(typetree(T, ctx, dl, seen)) - if !allocatedinline(T) + if !allocatedinline(T) && Base.isconcretetype(T) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, 0) end @@ -184,7 +184,7 @@ else function typetree_inner(::Type{<:GenericMemory{kind, T}}, ctx, dl, seen::TypeTreeTable) where {kind, T} offset = 0 tt = copy(typetree(T, ctx, dl, seen)) - if !allocatedinline(T) + if !allocatedinline(T) && Base.isconcretetype(T) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, 0) end diff --git a/test/runtests.jl b/test/runtests.jl index 6534d9a2a5..092cf8abce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3236,6 +3236,23 @@ end @test ad_eta[1] ≈ 0.0 end +function absset(out, x) + @inbounds out[1] = (x,) + return nothing +end + +@testset "Abstract Array element type" begin + out = Tuple{Any}[(9.7,)] + dout = Tuple{Any}[(4.3,)] + + autodiff(Enzyme.Forward, + absset, + Duplicated(out, dout), + Duplicated(3.1, 2.4) + ) + @test dout[1][1] ≈ 2.4 +end + @testset "Tape Width" begin struct Roo x::Float64 From 5f34de2dd48a273ed0977b1aef4956856c547ecb Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 27 Jul 2024 21:40:56 -0400 Subject: [PATCH 200/495] First relocation of globals (#1684) * First relocation of globals * Fix relocation * rework * load glob * with orcv1, maybe? 
* Support setting metadata on old llvm.jl * Vendor global set md * More 1.6 orc * add libdl * orcv1 * fix --- src/absint.jl | 14 +++ src/api.jl | 1 + src/compiler.jl | 224 +++++++++++++++++++++++++++------ src/compiler/optimize.jl | 6 +- src/compiler/orcv1.jl | 31 ++++- src/compiler/orcv2.jl | 13 ++ src/compiler/validation.jl | 2 +- src/rules/customrules.jl | 5 +- src/rules/jitrules.jl | 16 +-- src/rules/llvmrules.jl | 12 +- src/rules/parallelrules.jl | 18 ++- src/rules/typeunstablerules.jl | 42 ++++--- src/utils.jl | 90 +++++++++++-- 13 files changed, 377 insertions(+), 97 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 6462162bc6..80cb3b9d4f 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -103,6 +103,20 @@ function absint(arg::LLVM.Value, partial::Bool=false) return (true, typ) end + if isa(arg, GlobalVariable) + gname = LLVM.name(arg) + for (k, v) in JuliaGlobalNameMap + if gname == k || gname == "ejl_"*k + return (true, v) + end + end + for (k, v) in JuliaEnzymeNameMap + if gname == k || gname == "ejl_"*k + return (true, v) + end + end + end + if isa(arg, LLVM.LoadInst) && value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) ptr = operands(arg)[1] ce = ptr diff --git a/src/api.jl b/src/api.jl index da0958e75c..2eaae18901 100644 --- a/src/api.jl +++ b/src/api.jl @@ -103,6 +103,7 @@ struct CFnTypeInfo known_values::Ptr{IntList} end +SetMD(v::Union{LLVM.Instruction, LLVM.GlobalVariable}, kind::String, node::LLVM.Metadata) = ccall((:EnzymeSetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), v, kind, LLVM.Value(node)) @static if !isdefined(LLVM, :ValueMetadataDict) Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = diff --git a/src/compiler.jl b/src/compiler.jl index f709128580..f9fdb34a4c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -45,6 +45,8 @@ end import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error +include("compiler/utils.jl") + if LLVM.has_orc_v1() 
include("compiler/orcv1.jl") else @@ -53,7 +55,6 @@ end include("gradientutils.jl") -include("compiler/utils.jl") # Julia function to LLVM stem and arity const cmplx_known_ops = @@ -981,6 +982,113 @@ function emit_svec!(B, args)::LLVM.Value call!(B, fty, fn, [LLVM.ConstantInt(sz, length(args)), args...]) end +AnyArray(Length::Int) = NamedTuple{ntuple(i->Symbol(i), Val(Length)),NTuple{Length,Any}} + +const JuliaEnzymeNameMap = Dict{String, Any}( + + "enz_val_true" => Val(true), + "enz_val_false" => Val(false), + + "enz_val_1" => Val(1), + + "enz_any_array_1" => AnyArray(1), + "enz_any_array_2" => AnyArray(2), + "enz_any_array_3" => AnyArray(3) +) + +const JuliaGlobalNameMap = Dict{String, Any}( + "jl_type_type" => Type, + "jl_any_type" => Any, + "jl_datatype_type" => DataType, + "jl_methtable_type" => Core.MethodTable, + "jl_symbol_type" => Symbol, + "jl_simplevector_type" => Core.SimpleVector, + "jl_nothing_type" => Nothing, + + "jl_tvar_type" => TypeVar, + "jl_typeofbottom_type" => Core.TypeofBottom, + "jl_bottom_type" => Union{}, + "jl_unionall_type" => UnionAll, + + "jl_uniontype_type" => Union, + "jl_emptytuple_type" => Tuple{}, + "jl_emptytuple" => (), + "jl_int8_type" => Int8, + "jl_uint8_type" => UInt8, + "jl_int16_type" => Int16, + "jl_uint16_type" => UInt16, + "jl_int32_type" => Int32, + "jl_uint32_type" => UInt32, + "jl_int64_type" => Int64, + "jl_uint64_type" => UInt64, + "jl_float16_type" => Float16, + "jl_float32_type" => Float32, + "jl_float64_type" => Float64, + "jl_ssavalue_type" => Core.SSAValue, + "jl_slotnumber_type" => Core.SlotNumber, + "jl_argument_type" => Core.Argument, + "jl_bool_type" => Bool, + "jl_char_type" => Char, + "jl_false" => false, + "jl_true" => true, + "jl_abstractstring_type" => AbstractString, + "jl_string_type" => String, + "jl_an_empty_string" => "", + "jl_function_type" => Function, + "jl_builtin_type" => Core.Builtin, + "jl_module_type" => Core.Module, + "jl_globalref_type" => Core.GlobalRef, + "jl_ref_type" => Ref, + 
"jl_pointer_typename" => Ptr, + "jl_voidpointer_type" => Ptr{Nothing}, + + "jl_abstractarray_type" => AbstractArray, + + "jl_densearray_type" => DenseArray, + + "jl_array_type" => Array, + + "jl_array_any_type" => Array{Any, 1}, + + "jl_array_symbol_type" => Array{Symbol, 1}, + + "jl_array_uint8_type" => Array{UInt8, 1}, + + # "jl_array_uint32_type" => Array{UInt32, 1}, + + "jl_array_int32_type" => Array{Int32, 1}, + + + "jl_expr_type" => Expr, + + "jl_method_type" => Method, + "jl_method_instance_type" => Core.MethodInstance, + "jl_code_instance_type" => Core.CodeInstance, + "jl_const_type" => Core.Const, + "jl_llvmpointer_type" => Core.LLVMPtr, + + + "jl_namedtuple_type" => NamedTuple, + + "jl_task_type" => Task, + + "jl_uint8pointer_type" => Ptr{UInt8}, + + "jl_nothing" => nothing, + + "jl_anytuple_type" => Tuple, +) +@static if VERSION >= v"1.7.0" + JuliaGlobalNameMap["jl_vararg_type"] = Core.TypeofVararg + JuliaGlobalNameMap["jl_opaque_closure_type"] = Core.OpaqueClosure +end +@static if VERSION >= v"1.8.0" + JuliaGlobalNameMap["jl_array_uint64_type"] = Array{UInt64, 1} +end +@static if VERSION >= v"1.10.0" + JuliaGlobalNameMap["jl_binding_type"] = Core.Binding +end + include("absint.jl") function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value @@ -1001,7 +1109,7 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value end if legal - return unsafe_to_llvm(Ty{found...}) + return unsafe_to_llvm(B, Ty{found...}) end T_jlvalue = LLVM.StructType(LLVMType[]) @@ -1011,7 +1119,7 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) f_apply_type, _ = get_function!(mod, "jl_f_apply_type", generic_FT) - Ty = unsafe_to_llvm(Ty) + Ty = unsafe_to_llvm(B, Ty) @static if VERSION < v"1.9.0-" FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) @@ -1047,7 +1155,7 @@ function emit_tuple!(B, args)::LLVM.Value end if legal - return 
unsafe_to_llvm((found...,)) + return unsafe_to_llvm(B, (found...,)) end T_jlvalue = LLVM.StructType(LLVMType[]) @@ -1075,15 +1183,15 @@ function emit_tuple!(B, args)::LLVM.Value end function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value - legal, val = abs_typeof(arg) - if legal - return unsafe_to_llvm(val) - end - curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) + legal, val = abs_typeof(arg) + if legal + return unsafe_to_llvm(B, val) + end + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg=true) @@ -1101,7 +1209,7 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value sizeT = convert(LLVMType, Csize_t) psizeT = LLVM.PointerType(sizeT) - primalvaltys = LLVM.Value[unsafe_to_llvm(Core.Typeof(func))] + primalvaltys = LLVM.Value[unsafe_to_llvm(B, Core.Typeof(func))] for a in args push!(primalvaltys, emit_jltypeof!(B, a)) end @@ -1122,7 +1230,7 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value # sv = emit_svec!(B, tosv[2:end]) # - meth = unsafe_to_llvm(meth) + meth = unsafe_to_llvm(B, meth) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -1142,7 +1250,7 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value @static if VERSION < v"1.8.0-" methodmatch = call!(B, FT, worlds, LLVM.Value[tag, LLVM.ConstantInt(sizeT, world), minworld, maxworld]) else - methodmatch = call!(B, FT, worlds, LLVM.Value[tag, unsafe_to_llvm(nothing), LLVM.ConstantInt(sizeT, world), minworld, maxworld]) + methodmatch = call!(B, FT, worlds, LLVM.Value[tag, unsafe_to_llvm(B, nothing), LLVM.ConstantInt(sizeT, world), minworld, maxworld]) end # emit_jl!(B, methodmatch) # emit_jl!(B, emit_jltypeof!(B, methodmatch)) @@ -1297,13 +1405,6 @@ struct Return2 ret2::Any end -struct Return3 - ret1::Any - ret2::Any - ret3::Any -end -AnyArray(Length::Int) = 
NamedTuple{ntuple(i->Symbol(i), Val(Length)),NTuple{Length,Any}} - function permit_inlining!(f::LLVM.Function) for bb in blocks(f), inst in instructions(bb) # remove illegal invariant.load and jtbaa_const invariants @@ -2776,7 +2877,7 @@ function julia_default_tape_type(C::LLVM.API.LLVMContextRef) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) return T_prjlvalue.ref end -function julia_undef_value_for_type(Ty::LLVM.API.LLVMTypeRef, forceZero::UInt8)::LLVM.API.LLVMValueRef +function julia_undef_value_for_type(mod::LLVM.API.LLVMModuleRef, Ty::LLVM.API.LLVMTypeRef, forceZero::UInt8)::LLVM.API.LLVMValueRef ty = LLVM.LLVMType(Ty) if !any_jltypes(ty) if forceZero != 0 @@ -2786,7 +2887,7 @@ function julia_undef_value_for_type(Ty::LLVM.API.LLVMTypeRef, forceZero::UInt8): end end if isa(ty, LLVM.PointerType) - val = unsafe_to_llvm(nothing) + val = unsafe_nothing_to_llvm(LLVM.Module(mod)) if !is_opaque(ty) val = const_pointercast(val, LLVM.PointerType(eltype(ty), Tracked)) end @@ -2796,13 +2897,13 @@ function julia_undef_value_for_type(Ty::LLVM.API.LLVMTypeRef, forceZero::UInt8): return val.ref end if isa(ty, LLVM.ArrayType) - st = LLVM.Value(julia_undef_value_for_type(eltype(ty).ref, forceZero)) + st = LLVM.Value(julia_undef_value_for_type(mod, eltype(ty).ref, forceZero)) return ConstantArray(eltype(ty), [st for i in 1:length(ty)]).ref end if isa(ty, LLVM.StructType) vals = LLVM.Constant[] for st in LLVM.elements(ty) - push!(vals, LLVM.Value(julia_undef_value_for_type(st.ref, forceZero))) + push!(vals, LLVM.Value(julia_undef_value_for_type(mod, st.ref, forceZero))) end return ConstantStruct(ty, vals).ref end @@ -2820,7 +2921,9 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie @assert has rt = active_reg_inner(Ty, (), world) if rt == ActiveState || rt == MixedState - operands(V)[3] = unsafe_to_llvm(Base.RefValue{Ty}) + B = LLVM.IRBuilder() + position!(B, V) + operands(V)[3] = unsafe_to_llvm(B, Base.RefValue{Ty}) end end nothing @@ 
-2851,7 +2954,7 @@ function fixup_return(B, retval) if isa(ty, LLVM.StructType) elems = LLVM.elements(ty) if length(elems) == 2 && elems[1] == T_prjlvalue - fill_val = unsafe_to_llvm(nothing) + fill_val = unsafe_to_llvm(B, nothing) prev = extract_value!(B, retval, 0) eq = icmp!(B, LLVM.API.LLVMIntEQ, prev, LLVM.null(T_prjlvalue)) retval = select!(B, eq, insert_value!(B, retval, fill_val, 0), retval) @@ -2885,7 +2988,8 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) if isa(ty, LLVM.PointerType) if any_jltypes(ty) loc = gep!(builder, LLVMType, nobj, path) - fill_val = unsafe_to_llvm(nothing) + mod = LLVM.parent(LLVM.parent(Base.position(builder))) + fill_val = unsafe_nothing_to_llvm(mod) loc = bitcast!(builder, loc, LLVM.PointerType(T_prjlvalue, addrspace(value_type(loc)))) store!(builder, fill_val, loc) elseif zeroAll @@ -3024,7 +3128,7 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) end # Obtain tag - tag = unsafe_to_llvm(ETT) + tag = unsafe_to_llvm(B, ETT) else if sizeof(Int) == sizeof(Int64) boxed_count = emit_box_int64!(B, Count) @@ -3033,7 +3137,7 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) Count = trunc!(B, Count, T_size_t) boxed_count = emit_box_int32!(B, Count) end - tag = emit_apply_type!(B, NTuple, (boxed_count, unsafe_to_llvm(TT))) + tag = emit_apply_type!(B, NTuple, (boxed_count, unsafe_to_llvm(B, TT))) end # Check if Julia version has https://github.com/JuliaLang/julia/pull/46914 @@ -3126,6 +3230,40 @@ end include("rules/allocrules.jl") include("rules/llvmrules.jl") +for (k, v) in ( + ("enz_runtime_newtask_fwd", Enzyme.Compiler.runtime_newtask_fwd), + ("enz_runtime_newtask_augfwd", Enzyme.Compiler.runtime_newtask_augfwd), + + ("enz_runtime_generic_fwd", Enzyme.Compiler.runtime_generic_fwd), + ("enz_runtime_generic_augfwd", Enzyme.Compiler.runtime_generic_augfwd), + ("enz_runtime_generic_rev", Enzyme.Compiler.runtime_generic_rev), + + ("enz_runtime_iterate_fwd", 
Enzyme.Compiler.runtime_iterate_fwd), + ("enz_runtime_iterate_augfwd", Enzyme.Compiler.runtime_iterate_augfwd), + ("enz_runtime_iterate_rev", Enzyme.Compiler.runtime_iterate_rev), + + ("enz_runtime_newstruct_augfwd", Enzyme.Compiler.runtime_newstruct_augfwd), + ("enz_runtime_newstruct_rev", Enzyme.Compiler.runtime_newstruct_rev), + + ("enz_runtime_tuple_augfwd", Enzyme.Compiler.runtime_tuple_augfwd), + ("enz_runtime_tuple_rev", Enzyme.Compiler.runtime_tuple_rev), + + + ("enz_runtime_jl_getfield_aug", Enzyme.Compiler.rt_jl_getfield_aug), + ("enz_runtime_jl_getfield_rev", Enzyme.Compiler.rt_jl_getfield_rev), + + ("enz_runtime_idx_jl_getfield_aug", Enzyme.Compiler.idx_jl_getfield_aug), + ("enz_runtime_idx_jl_getfield_rev", Enzyme.Compiler.idx_jl_getfield_aug), + + ("enz_runtime_jl_setfield_aug", Enzyme.Compiler.rt_jl_setfield_aug), + ("enz_runtime_jl_setfield_rev", Enzyme.Compiler.rt_jl_setfield_rev), + + ("enz_runtime_error_if_differentiable", Enzyme.Compiler.error_if_differentiable), + ("enz_runtime_error_if_active", Enzyme.Compiler.error_if_active), +) + JuliaEnzymeNameMap[k] = v +end + function __init__() API.memmove_warning!(false) API.typeWarning!(false) @@ -3150,7 +3288,7 @@ function __init__() fixup_return, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) API.EnzymeSetUndefinedValueForType(@cfunction( - julia_undef_value_for_type, LLVM.API.LLVMValueRef, (LLVM.API.LLVMTypeRef,UInt8))) + julia_undef_value_for_type, LLVM.API.LLVMValueRef, (LLVM.API.LLVMModuleRef, LLVM.API.LLVMTypeRef,UInt8))) API.EnzymeSetShadowAllocRewrite(@cfunction( shadow_alloc_rewrite, Cvoid, (LLVM.API.LLVMValueRef,API.EnzymeGradientUtilsRef))) register_alloc_rules() @@ -3278,15 +3416,13 @@ function annotate!(mod, mode) API.EnzymeAttributeKnownFunctions(f.ref) end -@static if VERSION >= v"1.8-" for gname in inactiveglobs globs = LLVM.globals(mod) if haskey(globs, gname) glob = globs[gname] - metadata(glob)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) + 
API.SetMD(glob, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) end end -end for fname in inactivefns if haskey(fns, fname) @@ -3570,7 +3706,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt end end end - GPUCompiler.@safe_error "Enzyme: Could not find world", fn + throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end function enzyme_custom_extract_mi(orig::LLVM.Instruction, error=true) @@ -4081,6 +4217,12 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, realparms = LLVM.Value[] i = 1 + for attr in collect(function_attributes(enzymefn)) + if kind(attr) == "enzymejl_world" + push!(function_attributes(llvm_f), attr) + end + end + if returnRoots sret = params[i] i+= 1 @@ -4309,7 +4451,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end @assert !(isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) ) @assert Core.Compiler.isconstType(ty) - eval = makeInstanceOf(ty) + eval = makeInstanceOf(builder, ty) eval = fixup_abi(i, eval) ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) @@ -4339,7 +4481,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end for returnNum in 0:(count_Sret-1) eval = fixup_abi(returnNum+1, if count_llvm_Sret == 0 - makeInstanceOf(sret_types[returnNum+1]) + makeInstanceOf(builder, sret_types[returnNum+1]) elseif count_llvm_Sret == 1 val else @@ -4360,7 +4502,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, eval = fixup_abi(returnNum+1, if !isghostty(actualRetType) extract_value!(builder, val, returnNum) else - makeInstanceOf(sret_types[returnNum+1]) + makeInstanceOf(builder, sret_types[returnNum+1]) end) store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), 
LLVM.ConstantInt(LLVM.IntType(32), length(elements(jltype))-1 )])) returnNum+=1 @@ -4827,6 +4969,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end end + for attr in collect(function_attributes(entry_f)) + if kind(attr) == "enzymejl_world" + push!(function_attributes(wrapper_f), attr) + end + end + seen = TypeTreeTable() # emit IR performing the "conversions" let builder = IRBuilder() @@ -4982,7 +5130,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function position!(builder, BB) if isghostty(jlrettype) || Core.Compiler.isconstType(jlrettype) - fill_val = unsafe_to_llvm(jlrettype.instance) + fill_val = unsafe_to_llvm(builder, jlrettype.instance) ret!(builder, fill_val) else nobj = if sretPtr !== nothing @@ -5785,8 +5933,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT) end - push!(function_attributes(primalf), StringAttribute("enzymejl_world", string(job.world))) - if primal_job.config.target isa GPUCompiler.NativeCompilerTarget target_machine = JIT.get_tm() else diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 7aebf9be0f..c2fd190641 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -813,8 +813,12 @@ function fix_decayaddr!(mod::LLVM.Module) if operands(st)[2] == inst LLVM.API.LLVMSetOperand(st, 2-1, operands(inst)[1]) continue - end + end end + if isa(st, LLVM.LoadInst) + LLVM.API.LLVMSetOperand(st, 1-1, operands(inst)[1]) + continue + end # if isa(st, LLVM.InsertValueInst) # if operands(st)[1] == inst # push!(invalid, st) diff --git a/src/compiler/orcv1.jl b/src/compiler/orcv1.jl index 1b6bd2fe81..4bbaa0125f 100644 --- a/src/compiler/orcv1.jl +++ b/src/compiler/orcv1.jl @@ -1,6 +1,7 @@ module JIT using LLVM +using Libdl import LLVM: TargetMachine import GPUCompiler: CompilerJob, 
JuliaContext @@ -127,11 +128,37 @@ function resolver(name, ctx) name = name[2:end] end end - LLVM.API.LLVMSearchForAddressOfSymbol(name) + + found = false + val = nothing + hnd = Libdl.dlopen("libjulia") + for (k, v) in Compiler.JuliaGlobalNameMap + if "ejl_"*k == name + val = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) + found = true + break + end + end + + if !found + for (k, v) in Compiler.JuliaEnzymeNameMap + if "ejl_"*k == name + val = Compiler.unsafe_to_ptr(v) + found = true + break + end + end + end + + if found + val + else + LLVM.API.LLVMSearchForAddressOfSymbol(name) + end ## Step 4: Lookup in libatomic # TODO: Do we need to do this? catch ex - @error "Enzyme: Lookup failed" jl_name exception=(ex, Base.catch_backtrace()) + @error "Enzyme: Lookup failed" name exception=(ex, Base.catch_backtrace()) C_NULL end if ptr === C_NULL diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index e61560548b..90bcb540cd 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -1,6 +1,7 @@ module JIT using LLVM +using Libdl import LLVM:TargetMachine import GPUCompiler @@ -131,6 +132,18 @@ function __init__() jit[] = CompilerInstance(lljit, nothing, nothing) end + hnd = Libdl.dlopen("libjulia") + + for (k, v) in Compiler.JuliaGlobalNameMap + ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) + LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, "ejl_"*k), ptr)) + end + + for (k, v) in Compiler.JuliaEnzymeNameMap + ptr = Compiler.unsafe_to_ptr(v) + LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, "ejl_"*k), ptr)) + end + atexit() do @static if !use_ojit() ci = jit[] diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 3649188320..8bf562addf 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -861,7 +861,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width if reg == ActiveState || reg == MixedState NTy 
= Base.RefValue{Ty} @assert sizeof(Ty) == sizeof(NTy) - LLVM.API.LLVMSetOperand(cur, 2, unsafe_to_llvm(NTy)) + LLVM.API.LLVMSetOperand(cur, 2, unsafe_to_llvm(LLVM.IRBuilder(cur), NTy)) end continue end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 6a5c101786..3303dd6c59 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -31,7 +31,8 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) - world = enzyme_extract_world(LLVM.parent(LLVM.parent(orig))) + ofn = LLVM.parent(LLVM.parent(orig)) + world = enzyme_extract_world(ofn) for arg in jlargs @assert arg.cc != RemovedParam @@ -55,7 +56,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) - val = unsafe_to_llvm(arg.typ.parameters[1]) + val = unsafe_to_llvm(B, arg.typ.parameters[1]) store!(B, val, ptr) if any_jltypes(llty) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index c624d85202..8f68cec991 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1158,7 +1158,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, ActivityList = LLVM.Value[] @assert length(ops) != 0 - fill_val = unsafe_to_llvm(nothing) + fill_val = unsafe_to_llvm(B, nothing) vals = LLVM.Value[] @@ -1185,7 +1185,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, active = !is_constant_value(gutils, op) if !active - push!(ActivityList, unsafe_to_llvm(false)) + push!(ActivityList, unsafe_to_llvm(B, false)) else inverted = invert_pointer(gutils, op, B) if lookup @@ -1197,9 +1197,9 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else 
extract_value!(B, inverted, 0) end - push!(ActivityList, select!(B, icmp!(B, LLVM.API.LLVMIntNE, val, inv_0), unsafe_to_llvm(true), unsafe_to_llvm(false))) + push!(ActivityList, select!(B, icmp!(B, LLVM.API.LLVMIntNE, val, inv_0), unsafe_to_llvm(B, true), unsafe_to_llvm(B, false))) else - push!(ActivityList, unsafe_to_llvm(true)) + push!(ActivityList, unsafe_to_llvm(B, true)) end end @@ -1227,7 +1227,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, pushfirst!(vals, tape) end else - pushfirst!(vals, unsafe_to_llvm(Val(ReturnType))) + pushfirst!(vals, unsafe_to_llvm(B, Val(ReturnType))) end if firstconst && firstconst_after_tape @@ -1248,10 +1248,10 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, for idx in 1:(length(ops)+firstconst) push!(ModifiedBetween, uncacheable[(start-1)+idx] != 0) end - pushfirst!(vals, unsafe_to_llvm(Val((ModifiedBetween...,)))) + pushfirst!(vals, unsafe_to_llvm(B, Val((ModifiedBetween...,)))) end - pushfirst!(vals, unsafe_to_llvm(Val(Int(width)))) + pushfirst!(vals, unsafe_to_llvm(B, Val(Int(width)))) etup0 = emit_tuple!(B, ActivityList) etup = emit_apply_type!(B, Base.Val, [etup0]) if isa(etup, LLVM.Instruction) @@ -1264,7 +1264,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, mi = emit_methodinstance!(B, func, vals) end - pushfirst!(vals, unsafe_to_llvm(func)) + pushfirst!(vals, unsafe_to_llvm(B, func)) @static if VERSION < v"1.7.0-" || true else diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index e52c4e09c6..664d643af0 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -709,6 +709,8 @@ end return false end + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + width = get_width(gutils) origh, origkey, origdflt = operands(orig)[1:end-1] @@ -747,7 +749,7 @@ end newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowdflt] cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, 
newops, newvals, #=lookup=#false) callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, cal)]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)]) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) @@ -756,7 +758,7 @@ end newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), extract_value!(B, shadowdflt, j-1)] cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, cal)]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)]) shadow = insert_value!(B, shadow, cal, j-1) end shadow @@ -820,10 +822,12 @@ end invert_pointer(gutils, origval, B) end + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow, API.VT_None] shadowres = if width == 1 - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, shadowval)]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, shadowval)]) newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowval, LLVM.null(value_type(originserted))] cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) callconv!(cal, callconv(orig)) @@ -833,7 +837,7 @@ end shadow = LLVM.UndefValue(ST) for j in 1:width sval2 = extract_value!(B, shadowval, j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active), emit_jltypeof!(B, sval2)]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, sval2)]) newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), sval2, LLVM.null(value_type(originserted))] cal = 
call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) callconv!(cal, callconv(orig)) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index e4fbab6faf..651c05952e 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -306,7 +306,7 @@ end v = load!(B, pllty, v) end else - v = makeInstanceOf(ppfuncT) + v = makeInstanceOf(B, ppfuncT) end if refed @@ -516,7 +516,6 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) width = get_width(gutils) mode = get_mode(gutils) @@ -526,13 +525,13 @@ end ops = collect(operands(orig)) vals = LLVM.Value[ - unsafe_to_llvm(runtime_newtask_fwd), - unsafe_to_llvm(Val(world)), + unsafe_to_llvm(B, runtime_newtask_fwd), + unsafe_to_llvm(B, Val(world)), new_from_original(gutils, ops[1]), invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(Val(width)), + unsafe_to_llvm(B, Val(width)), ] ntask = emit_apply_generic!(B, vals) @@ -560,7 +559,6 @@ end end normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing - mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -577,14 +575,14 @@ end ops = collect(operands(orig)) vals = LLVM.Value[ - unsafe_to_llvm(runtime_newtask_augfwd), - unsafe_to_llvm(Val(world)), + unsafe_to_llvm(B, runtime_newtask_augfwd), + unsafe_to_llvm(B, Val(world)), new_from_original(gutils, ops[1]), invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), (sizeof(Int) == sizeof(Int64) ? emit_box_int64! 
: emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(Val(width)), - unsafe_to_llvm(Val(ModifiedBetween)), + unsafe_to_llvm(B, Val(width)), + unsafe_to_llvm(B, Val(ModifiedBetween)), ] ntask = emit_apply_generic!(B, vals) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 2765edecb9..dafc367ef3 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -872,20 +872,20 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta end AA = Val(AnyArray(Int(width))) - vals = LLVM.Value[unsafe_to_llvm(AA)] + vals = LLVM.Value[unsafe_to_llvm(B, AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[3]) sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[2])))) + push!(vals, unsafe_to_llvm(B, Val(is_constant_value(gutils, ops[2])))) for v in inps[2:end] push!(vals, v) end - pushfirst!(vals, unsafe_to_llvm(rt_jl_getfield_aug)) + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_getfield_aug)) cal = emit_apply_generic!(B, vals) @@ -965,13 +965,13 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[2])))) + push!(vals, unsafe_to_llvm(B, Val(is_constant_value(gutils, ops[2])))) for v in inps[2:end] push!(vals, v) end - pushfirst!(vals, unsafe_to_llvm(rt_jl_getfield_rev)) + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_getfield_rev)) cal = emit_apply_generic!(B, vals) @@ -1050,7 +1050,7 @@ end end AA = Val(AnyArray(Int(width))) - vals = LLVM.Value[unsafe_to_llvm(AA)] + vals = LLVM.Value[unsafe_to_llvm(B, AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[2]) @@ -1058,13 +1058,13 @@ end sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[1])))) + push!(vals, unsafe_to_llvm(B, 
Val(is_constant_value(gutils, ops[1])))) for v in inps[2:end] push!(vals, v) end - pushfirst!(vals, unsafe_to_llvm(idx_jl_getfield_aug)) + pushfirst!(vals, unsafe_to_llvm(B, idx_jl_getfield_aug)) cal = emit_apply_generic!(B, vals) @@ -1146,13 +1146,13 @@ end sym = emit_apply_type!(B, Base.Val, [sym]) push!(vals, sym) - push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[1])))) + push!(vals, unsafe_to_llvm(B, Val(is_constant_value(gutils, ops[1])))) for v in inps[2:end] push!(vals, v) end - pushfirst!(vals, unsafe_to_llvm(idx_jl_getfield_rev)) + pushfirst!(vals, unsafe_to_llvm(B, idx_jl_getfield_rev)) cal = emit_apply_generic!(B, vals) @@ -1262,16 +1262,18 @@ function common_setfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR nothing end + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + for idx in 1:width vals = LLVM.Value[ (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx-1), new_from_original(gutils, origops[3]), - unsafe_to_llvm(Val(is_constant_value(gutils, origops[4]))), + unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), new_from_original(gutils, origops[4]), - is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(nothing) : ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), ] - pushfirst!(vals, unsafe_to_llvm(rt_jl_setfield_aug)) + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_setfield_aug)) cal = emit_apply_generic!(B, vals) @@ -1294,17 +1296,19 @@ function common_setfield_rev(offset, B, orig, gutils, tape) else nothing end + + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) for idx in 1:width vals = LLVM.Value[ lookup_value(gutils, (width == 1) ? 
shadowstruct : extract_value!(B, shadowstruct, idx-1), B), lookup_value(gutils, new_from_original(gutils, origops[3]), B), - unsafe_to_llvm(Val(is_constant_value(gutils, origops[4]))), + unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), lookup_value(gutils, new_from_original(gutils, origops[4]), B), - is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(nothing) : lookup_value(gutils, ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), B), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : lookup_value(gutils, ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), B), ] - pushfirst!(vals, unsafe_to_llvm(rt_jl_setfield_rev)) + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_setfield_rev)) cal = emit_apply_generic!(B, vals) @@ -1375,6 +1379,8 @@ function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tap mi = new_from_original(gutils, origmi) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + shadowres = if width == 1 newops = LLVM.Value[mi, shadowh, new_from_original(gutils, origkey)] if offset != 1 @@ -1384,7 +1390,7 @@ function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tap callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(errfn), emit_jltypeof!(B, cal)]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)]) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) @@ -1396,7 +1402,7 @@ function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tap end cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(errfn), emit_jltypeof!(B, cal)]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)]) shadow = insert_value!(B, shadow, cal, j-1) end shadow diff --git a/src/utils.jl b/src/utils.jl index 
916818181e..fdc15db247 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -17,16 +17,21 @@ export Tracked, Derived const captured_constants = Base.IdSet{Any}() -# This mimicks literal_pointer_val / literal_pointer_val_slot -function unsafe_to_llvm(val) +function unsafe_nothing_to_llvm(mod::LLVM.Module) + globs = LLVM.globals(mod) + k = "jl_nothing" + if Base.haskey(globs, "ejl_"*k) + return globs["ejl_"*k] + end T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - # XXX: This prevents code from being runtime relocatable - # We likely should emit global variables and use something - # like `absolute_symbol_materialization` and write out cache-files - # that have relocation tables. - # TODO: What about things like `nothing` + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + + API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) + API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) + return gv +end + +function unsafe_to_ptr(@nospecialize(val)) if !Base.ismutable(val) val = Core.Box(val) # FIXME many objects could be leaked here @assert Base.ismutable(val) @@ -37,16 +42,77 @@ function unsafe_to_llvm(val) push!(captured_constants, val) # Globally root ptr = Base.pointer_from_objref(val) end + return ptr +end +export unsafe_to_ptr + +# This mimicks literal_pointer_val / literal_pointer_val_slot +function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) + T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) + + for (k, v) in Compiler.JuliaGlobalNameMap + if v === val + mod = LLVM.parent(LLVM.parent(LLVM.position(B))) + globs = LLVM.globals(mod) + if Base.haskey(globs, "ejl_"*k) + return globs["ejl_"*k] + end + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + + API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) + legal, jTy = 
Compiler.abs_typeof(gv, true) + if legal + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + world = Compiler.enzyme_extract_world(fn) + if Compiler.guaranteed_const_nongen(jTy, world) + API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) + end + end + return gv + end + end + + for (k, v) in Compiler.JuliaEnzymeNameMap + if v === val + mod = LLVM.parent(LLVM.parent(LLVM.position(B))) + globs = LLVM.globals(mod) + if Base.haskey(globs, "ejl_"*k) + return globs["ejl_"*k] + end + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) + legal, jTy = Compiler.abs_typeof(gv, true) + if legal + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + world = Compiler.enzyme_extract_world(fn) + if Compiler.guaranteed_const_nongen(jTy, world) + API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) + end + end + return gv + end + end + + # XXX: This prevents code from being runtime relocatable + # We likely should emit global variables and use something + # like `absolute_symbol_materialization` and write out cache-files + # that have relocation tables. 
+ ptr = unsafe_to_ptr(val) + fill_val = LLVM.ConstantInt(convert(UInt, ptr)) fill_val = LLVM.const_inttoptr(fill_val, T_prjlvalue_UT) LLVM.const_addrspacecast(fill_val, T_prjlvalue) end -export unsafe_to_llvm +export unsafe_to_llvm, unsafe_nothing_to_llvm -function makeInstanceOf(@nospecialize(T)) +function makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T)) @assert Core.Compiler.isconstType(T) @assert T <: Type - return unsafe_to_llvm(T.parameters[1]) + return unsafe_to_llvm(B, T.parameters[1]) end export makeInstanceOf From b0e13be0ff478267b1981093fdd4e8e7d981b812 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 28 Jul 2024 00:12:03 -0400 Subject: [PATCH 201/495] Simplify libjulia usage (#1687) * Simplify libjulia usage * try fix * fx * champ * tmp * fix * bump --- src/compiler/orcv1.jl | 3 ++- src/compiler/orcv2.jl | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/compiler/orcv1.jl b/src/compiler/orcv1.jl index 4bbaa0125f..bcac867e73 100644 --- a/src/compiler/orcv1.jl +++ b/src/compiler/orcv1.jl @@ -131,7 +131,8 @@ function resolver(name, ctx) found = false val = nothing - hnd = Libdl.dlopen("libjulia") + + hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) for (k, v) in Compiler.JuliaGlobalNameMap if "ejl_"*k == name val = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 90bcb540cd..61971f47ed 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -132,8 +132,11 @@ function __init__() jit[] = CompilerInstance(lljit, nothing, nothing) end - hnd = Libdl.dlopen("libjulia") - + hnd = @static if VERSION >= v"1.10" + unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) + else + Libdl.dlopen("libjulia") + end for (k, v) in Compiler.JuliaGlobalNameMap ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, "ejl_"*k), ptr)) From 
b1000f016c1e3eaaf5b4ed2733af6bb55e3eb5f1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 28 Jul 2024 13:39:58 -0400 Subject: [PATCH 202/495] Update Project.toml (#1689) --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 2760fa45a1..4e7abf5e67 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.25" +version = "0.12.26" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -32,7 +32,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.7" -Enzyme_jll = "0.0.138" +Enzyme_jll = "0.0.140" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From d89dbce9c3f652c81b038e57633a94a114bff9e3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 28 Jul 2024 13:43:11 -0400 Subject: [PATCH 203/495] Update pipeline.yml --- .buildkite/pipeline.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5ce458d7d1..6de936a558 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -3,8 +3,6 @@ steps: matrix: setup: version: - - "1.8" - - "1.9" - "1.10" plugins: - JuliaCI/julia#v1: From 9b099a6e568c6e26cdce1003263ca4067af03245 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 1 Aug 2024 13:38:54 -0400 Subject: [PATCH 204/495] Add dump mod post optimizations option (#1694) --- src/api.jl | 4 ++++ src/compiler.jl | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/src/api.jl b/src/api.jl index 2eaae18901..d5cb2a5451 100644 --- a/src/api.jl +++ b/src/api.jl @@ -763,6 +763,10 @@ function EnzymeReplaceFunctionImplementation(mod) ccall((:EnzymeReplaceFunctionImplementation, libEnzyme),Cvoid,(LLVM.API.LLVMModuleRef,), mod) end +function EnzymeDumpModuleRef(mod) + ccall((:EnzymeDumpModuleRef, 
libEnzyme),Cvoid,(LLVM.API.LLVMModuleRef,), mod) +end + EnzymeComputeByteOffsetOfGEP(B, V, T) = LLVM.Value(ccall((:EnzymeComputeByteOffsetOfGEP, libEnzyme), LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMTypeRef), B, V, T)) EnzymeAllocaType(al) = LLVM.LLVMType(ccall((:EnzymeAllocaType, libEnzyme), LLVM.API.LLVMTypeRef, (LLVM.API.LLVMValueRef,), al)) diff --git a/src/compiler.jl b/src/compiler.jl index f9fdb34a4c..33f14424a6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6866,6 +6866,8 @@ function _link(job, (mod, adjoint_name, primal_name, TapeType)) return CompileResult(adjoint_ptr, primal_ptr, TapeType) end +const DumpPostOpt = Ref(false) + # actual compilation function _thunk(job, postopt::Bool=true) mod, meta = codegen(:llvm, job; optimize=false) @@ -6888,6 +6890,9 @@ function _thunk(job, postopt::Bool=true) if postopt if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) + if DumpPostOpt[] + API.EnzymeDumpModuleRef(mod.ref) + end else propagate_returned!(mod) end From 8a850c77be67604e9b9429be19f2876db29899b3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 1 Aug 2024 13:39:19 -0400 Subject: [PATCH 205/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4e7abf5e67..e47535efbd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.26" +version = "0.12.27" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -32,7 +32,7 @@ EnzymeStaticArraysExt = "StaticArrays" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.7" -Enzyme_jll = "0.0.140" +Enzyme_jll = "0.0.141" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From c44bf01460942f2c94b97fd8b2b93755e1973a32 Mon Sep 17 00:00:00 2001 
From: William Moses Date: Sun, 4 Aug 2024 13:08:54 -0400 Subject: [PATCH 206/495] Fix interpreter caches (#1698) * Sdebug alloc inline * Update Project.toml * Update Enzyme.jl * Update compiler.jl * Update compiler.jl * Update compiler.jl * Fix 1.6/1.7 --- src/compiler.jl | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 33f14424a6..9f7ca6c373 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3361,20 +3361,38 @@ struct EnzymeCacheToken always_inline method_table::Core.MethodTable param_type::Type - mode::API.CDerivativeMode + is_fwd::API.CDerivativeMode end GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = EnzymeCacheToken( typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job), - typeof(job.config.params), job.config.params.mode, + typeof(job.config.params), job.config.params.mode == API.DEM_ForwardMode, ) GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) else + +# the codeinstance cache to use -- should only be used for the constructor +# Note that the only way the interpreter modifies codegen is either not inlining a fwd mode +# rule or not inlining a rev mode rule. Otherwise, all caches can be re-used. 
+const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache() +const GLOBAL_REV_CACHE = GPUCompiler.CodeCache() +function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) + return if job.config.params.mode == API.DEM_ForwardMode + GLOBAL_FWD_CACHE + else + GLOBAL_REV_CACHE + end +end + +@static if VERSION < v"1.8" +GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = enzyme_ci_cache(job) +end + GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) + Interpreter.EnzymeInterpreter(enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) end include("compiler/passes.jl") @@ -6952,7 +6970,7 @@ end run_enzyme = false Const else - A + A end if run_enzyme && !(A2 <: Const) && guaranteed_const_nongen(rrt, World) From c0c07c318acad2524aba44a4d521298dc253302c Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Aug 2024 13:12:06 -0400 Subject: [PATCH 207/495] Interpreter: single return type (#1703) * Interpreter: single return type * fix --- src/Enzyme.jl | 36 +++++++++++++++++++++++++----------- src/compiler.jl | 20 ++++++++++++++++++++ 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 9ff56bdd81..72ed79eca0 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -229,7 +229,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. 
""" -@inline function autodiff(::ReverseMode{ReturnPrimal, RABI,Holomorphic}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs} +@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RABI,Holomorphic}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs} tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 @@ -239,17 +239,23 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - - rt = if A isa UnionAll - Core.Compiler.return_type(f.val, tt) - else - eltype(A) - end + + FTy = Core.Typeof(f.val) opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) else - Val(codegen_world_age(Core.Typeof(f.val), tt)) + Val(codegen_world_age(FTy, tt)) + end + + rt = if A isa UnionAll + @static if VERSION >= v"1.8.0" + Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) + else + Core.Compiler.return_type(f.val, tt) + end + else + eltype(A) end if A <: Active @@ -333,7 +339,11 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. """ @inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = Core.Compiler.return_type(f.val, tt) + rt = if mode isa ReverseMode && VERSION >= v"1.8.0" + Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) + else + Core.Compiler.return_type(f.val, tt) + end A = guess_activity(rt, mode) autodiff(mode, f, A, args...) 
end @@ -546,8 +556,12 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur @inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs} tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) - rt = Core.Compiler.return_type(f.val, tt) + rt = if mode isa ReverseMode && VERSION >= v"1.8.0" + Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) + else + Core.Compiler.return_type(f.val, tt) + end + if rt === Union{} error("return type is Union{}, giving up.") end diff --git a/src/compiler.jl b/src/compiler.jl index 9f7ca6c373..ea56926297 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3420,6 +3420,26 @@ Create the methodinstance pair, and lookup the primal return type. return primal end +@generated function primal_return_type(::ReverseMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} + mode = Enzyme.API.DEM_ReverseModeCombined + interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(Enzyme.Compiler.GLOBAL_REV_CACHE, nothing, world, mode) + res = Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) + return quote + Base.@_inline_meta + $res + end +end + +@generated function primal_return_type(::ForwardMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} + mode = Enzyme.API.DEM_ForwardMode + interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(Enzyme.Compiler.GLOBAL_FWD_CACHE, nothing, world, mode) + res = Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) + return quote + Base.@_inline_meta + $res + end +end + ## # Enzyme compiler step ## From 06b23f50ccd7f61c344907f84219462595ea8eb9 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 7 Aug 2024 10:05:37 +0200 Subject: [PATCH 208/495] allow forward mode reflection (#1702) --- src/compiler/reflection.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/src/compiler/reflection.jl b/src/compiler/reflection.jl index ec0165ec0e..b9838e1620 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -43,9 +43,9 @@ end function enzyme_code_llvm(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types); optimize::Bool=true, run_enzyme::Bool=true, second_stage::Bool=true, - raw::Bool=false, debuginfo::Symbol=:default, dump_module::Bool=false) + raw::Bool=false, debuginfo::Symbol=:default, dump_module::Bool=false, mode=API.DEM_ReverseModeCombined) JuliaContext() do ctx - entry_fn, ir = reflect(func, A, types; optimize, run_enzyme, second_stage) + entry_fn, ir = reflect(func, A, types; optimize, run_enzyme, second_stage, mode) @static if VERSION >= v"1.9.0-DEV.516" ts_mod = ThreadSafeModule(ir) if VERSION >= v"1.9.0-DEV.672" @@ -76,9 +76,9 @@ function enzyme_code_llvm(io::IO, @nospecialize(func), @nospecialize(A), @nospec end enzyme_code_llvm(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = enzyme_code_llvm(stdout, func, A, types; kwargs...) 
-function enzyme_code_native(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types)) +function enzyme_code_native(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types); mode=API.DEM_ReverseModeCombined) JuliaContext() do ctx - _, mod = reflect(func, A, types) + _, mod = reflect(func, A, types; mode) str = String(LLVM.emit(JIT.get_tm(), mod, LLVM.API.LLVMAssemblyFile)) print(io, str) end From 037dfed7a2a5af4b577546ff7c879a3adbbe7058 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Aug 2024 08:08:27 -0700 Subject: [PATCH 209/495] more 1.11 workarounds (#1704) * Change to newpm passbuilder * Update optimize.jl * fix * Fix * fixup * fix * managers * fixup * Add revert * more 1.11 workarounds * fix * mt * fix * Update optimize.jl * Update optimize.jl * fix * fix * fix * Update optimize.jl --- src/compiler.jl | 46 +++++-- src/compiler/optimize.jl | 288 ++++++++++++++++++++++----------------- 2 files changed, 193 insertions(+), 141 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index ea56926297..b6207b5813 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3361,7 +3361,7 @@ struct EnzymeCacheToken always_inline method_table::Core.MethodTable param_type::Type - is_fwd::API.CDerivativeMode + is_fwd::Bool end GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = @@ -3422,7 +3422,17 @@ end @generated function primal_return_type(::ReverseMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} mode = Enzyme.API.DEM_ReverseModeCombined - interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(Enzyme.Compiler.GLOBAL_REV_CACHE, nothing, world, mode) + + CT = @static if VERSION >= v"1.11.0-DEV.1552" + EnzymeCacheToken( + typeof(DefaultCompilerTarget()), #=job.config.always_inline=#false, GPUCompiler.GLOBAL_METHOD_TABLE, + EnzymeCompilerParams, false, + ) + else + Enzyme.Compiler.GLOBAL_REV_CACHE + end + + interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) res = 
Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) return quote Base.@_inline_meta @@ -3432,7 +3442,17 @@ end @generated function primal_return_type(::ForwardMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} mode = Enzyme.API.DEM_ForwardMode - interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(Enzyme.Compiler.GLOBAL_FWD_CACHE, nothing, world, mode) + + CT = @static if VERSION >= v"1.11.0-DEV.1552" + EnzymeCacheToken( + typeof(DefaultCompilerTarget()), #=always_inline=#false, GPUCompiler.GLOBAL_METHOD_TABLE, + EnzymeCompilerParams, false, + ) + else + Enzyme.Compiler.GLOBAL_FWD_CACHE + end + + interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) res = Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) return quote Base.@_inline_meta @@ -4000,7 +4020,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr end ModulePassManager() do pm dce!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end fix_decayaddr!(mod) adjointf = adjointf == nothing ? 
nothing : functions(mod)[adjointfname] @@ -5290,7 +5310,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function ModulePassManager() do pm always_inliner!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end if !hasReturnsTwice LLVM.API.LLVMRemoveEnumAttributeAtIndex(wrapper_f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) @@ -5368,7 +5388,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # Kill the temporary staging function global_dce!(pm) global_optimizer!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end if haskey(globals(mod), "llvm.used") unsafe_delete!(mod, globals(mod)["llvm.used"]) @@ -5446,7 +5466,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; LLVM.ModulePassManager() do pm API.AddPreserveNVVMPass!(pm, #=Begin=#true) - run!(pm, mod) + LLVM.run!(pm, mod) end primalf = meta.entry @@ -5474,7 +5494,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if bitcode_replacement() && API.EnzymeBitcodeReplacement(mod, disableFallback, found) != 0 ModulePassManager() do pm instruction_combining!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end toremove = [] for f in functions(mod) @@ -5517,7 +5537,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; ModulePassManager() do pm always_inliner!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end for fname in toremove if haskey(functions(mod), fname) @@ -5528,7 +5548,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; GPUCompiler.@safe_warn "Using fallback BLAS replacements for ($found), performance may be degraded" ModulePassManager() do pm global_optimizer!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end end @@ -6157,7 +6177,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end ModulePassManager() do pm always_inliner!(pm) - run!(pm, mod) + 
LLVM.run!(pm, mod) end for fname in toremove if haskey(functions(mod), fname) @@ -6172,7 +6192,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; LLVM.ModulePassManager() do pm API.AddPreserveNVVMPass!(pm, #=Begin=#false) - run!(pm, mod) + LLVM.run!(pm, mod) end if parent_job !== nothing if parent_job.config.target isa GPUCompiler.PTXCompilerTarget @@ -6921,7 +6941,7 @@ function _thunk(job, postopt::Bool=true) LLVM.ModulePassManager() do pm add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) - run!(pm, mod) + LLVM.run!(pm, mod) end # Run post optimization pipeline diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index c2fd190641..a5a4908def 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -30,11 +30,11 @@ end function run_jl_pipeline(pm, tm; kwargs...) config = Ref(pipeline_options(;kwargs...)) function jl_pipeline(m) - @dispose pb=PassBuilder(tm) begin - NewPMModulePassManager(pb) do mpm + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm @ccall jl_build_newpm_pipeline(mpm.ref::Ptr{Cvoid}, pb.ref::Ptr{Cvoid}, config::Ptr{PipelineConfig})::Cvoid - run!(mpm, m, tm) end + LLVM.run!(mpm, m, tm) end return true end @@ -52,16 +52,18 @@ end end else function gc_invariant_verifier_tm!(pm, tm, cond) - function gc_invariant_verifier(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, GCInvariantVerifierPass(GCInvariantVerifierPassOptions(;strong=cond))) - run!(fpm, f, tm) + function gc_invariant_verifier(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, GCInvariantVerifierPass(;strong=cond)) + end end + run!(pb, mod) end return true end - add!(pm, FunctionPass("GCInvariantVerifier", gc_invariant_verifier)) + add!(pm, ModulePass("GCInvariantVerifier", gc_invariant_verifier)) end end @@ -71,16 +73,18 @@ end end else function 
propagate_julia_addrsp_tm!(pm, tm) - function prop_julia_addr(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, PropagateJuliaAddrspacesPass()) - run!(fpm, f, tm) + function prop_julia_addr(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, PropagateJuliaAddrspacesPass()) + end end + run!(pb, mod) end return true end - add!(pm, FunctionPass("PropagateJuliaAddrSpace", prop_julia_addr)) + add!(pm, ModulePass("PropagateJuliaAddrSpace", prop_julia_addr)) end end @@ -90,16 +94,18 @@ end end else function alloc_opt_tm!(pm, tm) - function alloc_opt(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, AllocOptPass()) - run!(fpm, f, tm) + function alloc_opt(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, AllocOptPass()) + end end + run!(pb, mod) end return true end - add!(pm, FunctionPass("AllocOpt", alloc_opt)) + add!(pm, ModulePass("AllocOpt", alloc_opt)) end end @@ -109,12 +115,12 @@ end end else function remove_ni_tm!(pm, tm) - function remove_ni(f) - @dispose pb=PassBuilder(tm) begin - NewPMModulePassManager(pb) do fpm - add!(fpm, RemoveNIPass()) - run!(fpm, f, tm) + function remove_ni(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, RemoveNIPass()) end + run!(pb, mod) end return true end @@ -128,17 +134,21 @@ end end else function julia_licm_tm!(pm, tm) - function julia_licm(f) - @dispose pb=PassBuilder(tm) begin - NewPMLoopPassManager(pb) do fpm - add!(fpm, JuliaLICMPass()) - run!(fpm, f, tm) + function julia_licm(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, NewPMLoopPassManager()) do lpm + add!(lpm, JuliaLICMPass()) + end + end end + run!(pb, mod) end 
return true end # really looppass - add!(pm, FunctionPass("JuliaLICM", julia_licm)) + add!(pm, ModulePass("JuliaLICM", julia_licm)) end end @@ -148,17 +158,73 @@ end end else function lower_simdloop_tm!(pm, tm) - function lower_simdloop(f) - @dispose pb=PassBuilder(tm) begin - NewPMLoopPassManager(pb) do fpm - add!(fpm, LowerSIMDLoopPass()) - run!(fpm, f, tm) + function lower_simdloop(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, NewPMLoopPassManager()) do lpm + add!(lpm, LowerSIMDLoopPass()) + end + end end + run!(pb, mod) end return true end # really looppass - add!(pm, FunctionPass("LowerSIMDLoop", lower_simdloop)) + add!(pm, ModulePass("LowerSIMDLoop", lower_simdloop)) + end +end + + +function loop_optimizations_tm!(pm, tm) + @static if true || VERSION < v"1.11-" + lower_simdloop_tm!(pm, tm) + licm!(pm) + if LLVM.version() >= v"15" + simple_loop_unswitch_legacy!(pm) + else + loop_unswitch!(pm) + end + else + run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) + end +end + + +function more_loop_optimizations_tm!(pm, tm) + @static if true || VERSION < v"1.11-" + loop_rotate!(pm) + # moving IndVarSimplify here prevented removing the loop in perf_sumcartesian(10:-1:1) + loop_idiom!(pm) + + # LoopRotate strips metadata from terminator, so run LowerSIMD afterwards + lower_simdloop_tm!(pm, tm) # Annotate loop marked with "loopinfo" as LLVM parallel loop + licm!(pm) + julia_licm_tm!(pm, tm) + # Subsequent passes not stripping metadata from terminator + instruction_combining!(pm) # TODO: createInstSimplifyLegacy + jl_inst_simplify!(pm) + + ind_var_simplify!(pm) + loop_deletion!(pm) + loop_unroll!(pm) # 
TODO: in Julia createSimpleLoopUnroll + else + # LowerSIMDLoopPass + # LoopRotatePass [opt >= 2] + # LICMPass + # JuliaLICMPass + # SimpleLoopUnswitchPass + # LICMPass + # JuliaLICMPass + # IRCEPass + # LoopInstSimplifyPass + # - in ours this is instcombine with jlinstsimplify + # LoopIdiomRecognizePass + # IndVarSimplifyPass + # LoopDeletionPass + # LoopFullUnrollPass + run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) end end @@ -168,16 +234,18 @@ end end else function demote_float16_tm!(pm, tm) - function demote_float16(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, DemoteFloat16Pass()) - run!(fpm, f, tm) + function demote_float16(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, DemoteFloat16Pass()) + end end + run!(pb, mod) end return true end - add!(pm, FunctionPass("DemoteFloat16", demote_float16)) + add!(pm, ModulePass("DemoteFloat16", demote_float16)) end end @@ -187,16 +255,18 @@ end end else function lower_exc_handlers_tm!(pm, tm) - function lower_exc_handlers(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, LowerExcHandlersPass()) - run!(fpm, f, tm) + function lower_exc_handlers(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, LowerExcHandlersPass()) + end end + run!(pb, mod) end return true end - add!(pm, FunctionPass("LowerExcHandlers", lower_exc_handlers)) + add!(pm, ModulePass("LowerExcHandlers", lower_exc_handlers)) end end @@ -206,12 +276,12 @@ end end else function lower_ptls_tm!(pm, tm, dump_native) - 
function lower_ptls(f) - @dispose pb=PassBuilder(tm) begin - NewPMModulePassManager(pb) do fpm - add!(fpm, LowerPTLSPass()) - run!(fpm, f, tm) + function lower_ptls(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, LowerPTLSPass()) end + run!(pb, mod) end return true end @@ -225,16 +295,18 @@ end end else function combine_mul_add_tm!(pm, tm) - function combine_mul_add(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, CombineMulAddPass()) - run!(fpm, f, tm) + function combine_mul_add(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, CombineMulAddPass()) + end end + run!(pb, mod) end return true end - add!(pm, FunctionPass("CombineMulAdd", combine_mul_add)) + add!(pm, ModulePass("CombineMulAdd", combine_mul_add)) end end @@ -244,16 +316,18 @@ end end else function late_lower_gc_frame_tm!(pm, tm) - function late_lower_gc_frame(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, LateLowerGCPass()) - run!(fpm, f, tm) + function late_lower_gc_frame(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, LateLowerGCPass()) + end end + run!(pb, mod) end return true end - add!(pm, FunctionPass("LateLowerGCFrame", late_lower_gc_frame)) + add!(pm, ModulePass("LateLowerGCFrame", late_lower_gc_frame)) end end @@ -263,16 +337,18 @@ end end else function final_lower_gc_tm!(pm, tm) - function final_lower_gc(f) - @dispose pb=PassBuilder(tm) begin - NewPMFunctionPassManager(pb) do fpm - add!(fpm, FinalLowerGCPass()) - run!(fpm, f, tm) + function final_lower_gc(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, NewPMFunctionPassManager()) do fpm + add!(fpm, FinalLowerGCPass()) + end end + run!(pb, mod) end return true end - 
add!(pm, FunctionPass("FinalLowerGCFrame", final_lower_gc)) + add!(pm, ModulePass("FinalLowerGCFrame", final_lower_gc)) end end @@ -330,7 +406,7 @@ end # # turn this into load/store, as this is more # amenable to caching analysis infrastructure -function memcpy_alloca_to_loadstore(mod) +function memcpy_alloca_to_loadstore(mod::LLVM.Module) dl = datalayout(mod) for f in functions(mod) if length(blocks(f)) != 0 @@ -1806,7 +1882,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) # callsites. See: https://godbolt.org/z/9Y3Gv6q5M ModulePassManager() do pm global_dce!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end # Prevent dead-arg-elimination of functions which we may require args for in the derivative funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg=true) @@ -1895,7 +1971,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? cse!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end propagate_returned!(mod) pre_attr!(mod) @@ -1903,7 +1979,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) if LLVM.version().major >= 13 ModulePassManager() do pm API.EnzymeAddAttributorLegacyPass(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end end end @@ -1919,7 +1995,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) end end cse!(pm) - run!(pm, mod) + LLVM.run!(pm, mod) end post_attr!(mod) propagate_returned!(mod) @@ -1991,17 +2067,7 @@ end loop_idiom!(pm) loop_rotate!(pm) - if VERSION < v"1.11-" - lower_simdloop_tm!(pm, tm) - licm!(pm) - if LLVM.version() >= v"15" - simple_loop_unswitch_legacy!(pm) - else - loop_unswitch!(pm) - end - else - run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) - end + loop_optimizations_tm!(pm, tm) 
instruction_combining!(pm) jl_inst_simplify!(pm) @@ -2030,7 +2096,7 @@ end correlated_value_propagation!(pm) # SLP_Vectorizer -- not for Enzyme - run!(pm, mod) + LLVM.run!(pm, mod) aggressive_dce!(pm) instruction_combining!(pm) @@ -2048,7 +2114,7 @@ end jl_inst_simplify!(pm) LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Exxtra gvn!(pm) # Exxtra - run!(pm, mod) + LLVM.run!(pm, mod) end removeDeadArgs!(mod, tm) detect_writeonly!(mod) @@ -2103,41 +2169,7 @@ function addOptimizationPasses!(pm, tm) # remove those before optimizing loops. alloc_opt_tm!(pm, tm) - - if VERSION < v"1.11-" - loop_rotate!(pm) - # moving IndVarSimplify here prevented removing the loop in perf_sumcartesian(10:-1:1) - loop_idiom!(pm) - - # LoopRotate strips metadata from terminator, so run LowerSIMD afterwards - lower_simdloop_tm!(pm, tm) # Annotate loop marked with "loopinfo" as LLVM parallel loop - licm!(pm) - julia_licm_tm!(pm, tm) - # Subsequent passes not stripping metadata from terminator - instruction_combining!(pm) # TODO: createInstSimplifyLegacy - jl_inst_simplify!(pm) - - ind_var_simplify!(pm) - loop_deletion!(pm) - loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll - else - # LowerSIMDLoopPass - # LoopRotatePass [opt >= 2] - # LICMPass - # JuliaLICMPass - # SimpleLoopUnswitchPass - # LICMPass - # JuliaLICMPass - # IRCEPass - # LoopInstSimplifyPass - # - in ours this is instcombine with jlinstsimplify - # LoopIdiomRecognizePass - # IndVarSimplifyPass - # LoopDeletionPass - # LoopFullUnrollPass - run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) - end - + more_loop_optimizations_tm!(pm, tm) # Run our own SROA on heap objects before LLVM's alloc_opt_tm!(pm, tm) @@ -2240,7 +2272,7 @@ function post_optimze!(mod, tm, 
machine=true) LLVM.ModulePassManager() do pm addTargetPasses!(pm, tm, LLVM.triple(mod)) addOptimizationPasses!(pm, tm) - run!(pm, mod) + LLVM.run!(pm, mod) end if machine # TODO enable validate_return_roots @@ -2248,7 +2280,7 @@ function post_optimze!(mod, tm, machine=true) LLVM.ModulePassManager() do pm addJuliaLegalizationPasses!(pm, tm, true) addMachinePasses!(pm, tm) - run!(pm, mod) + LLVM.run!(pm, mod) end end # @safe_show "post_mod", mod From fc5577c75db1957e8b6e48bcc7d1eca27df8f4fd Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Aug 2024 08:41:14 -0700 Subject: [PATCH 210/495] Add optional [but presently on] check if function contains written-to data (#1701) * Functioning * cleanup * fixup * Improving error msg * cleanup * fixup * fixup * ix * bump versions * Update Project.toml * fixup * allocinst * skip done * fix * fix * fix par * Disable ename map on < 1.8 --- Project.toml | 6 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 33 +++-- src/Enzyme.jl | 64 ++++---- src/absint.jl | 37 ++++- src/compiler.jl | 241 ++++++++++++++++++++++++------- src/compiler/reflection.jl | 4 +- src/compiler/utils.jl | 30 ++++ src/internal_rules.jl | 4 +- src/rules/jitrules.jl | 12 +- src/rules/parallelrules.jl | 8 +- src/utils.jl | 8 +- test/abi.jl | 44 +++++- test/amdgpu.jl | 2 - test/internal_rules.jl | 9 +- test/rrules.jl | 2 +- test/runtests.jl | 36 +++-- 17 files changed, 387 insertions(+), 155 deletions(-) diff --git a/Project.toml b/Project.toml index e47535efbd..9cdd1ebe9c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.27" +version = "0.12.28" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -31,8 +31,8 @@ EnzymeStaticArraysExt = "StaticArrays" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.7" -Enzyme_jll = "0.0.141" +EnzymeCore = "0.7.8" +Enzyme_jll = "0.0.142" 
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index e0861dadda..c0c4e0b1e6 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.7" +version = "0.7.8" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index fa0a31d44a..94df9a61c5 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -209,7 +209,7 @@ const DefaultABI = FFIABI Abstract type for what differentiation mode will be used. """ -abstract type Mode{ABI} end +abstract type Mode{ABI, ErrIfFuncWritten} end """ struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} @@ -219,11 +219,14 @@ Reverse mode differentiation. - `ABI`: What runtime ABI to use - `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz """ -struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} end -const Reverse = ReverseMode{false,DefaultABI, false}() -const ReverseWithPrimal = ReverseMode{true,DefaultABI, false}() -const ReverseHolomorphic = ReverseMode{false,DefaultABI, true}() -const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true}() +struct ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} end +const Reverse = ReverseMode{false,DefaultABI, false, false}() +const ReverseWithPrimal = ReverseMode{true,DefaultABI, false, false}() +const ReverseHolomorphic = ReverseMode{false,DefaultABI, true, false}() +const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true, false}() + +@inline set_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = 
ReverseMode{ReturnPrimal,ABI,Holomorphic,true}() +@inline clear_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,false}() """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI} @@ -234,19 +237,23 @@ Reverse mode differentiation. - `Width`: Batch Size (0 if to be automatically derived) - `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). """ -struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI} end -const ReverseSplitNoPrimal = ReverseModeSplit{false, true, 0, true,DefaultABI}() -const ReverseSplitWithPrimal = ReverseModeSplit{true, true, 0, true,DefaultABI}() -@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, Width, MBO, ABI}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,Width,MB,MBO,ABI} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI}() -@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, WidthO, MB, ABI}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,Width,MB,WidthO,ABI} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI}() +struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} end +const ReverseSplitNoPrimal = ReverseModeSplit{false, true, 0, true,DefaultABI, false}() +const ReverseSplitWithPrimal = ReverseModeSplit{true, true, 0, true,DefaultABI, false}() +@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() +@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where 
{ReturnPrimal,ReturnShadow,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() """ struct Forward <: Mode Forward mode differentiation """ -struct ForwardMode{ABI} <: Mode{ABI} +struct ForwardMode{ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} end -const Forward = ForwardMode{DefaultABI}() +const Forward = ForwardMode{DefaultABI, false}() + + +@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,true}() +@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,false}() function autodiff end function autodiff_deferred end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 72ed79eca0..076eb62761 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,8 +5,8 @@ import EnzymeCore import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI +import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written +export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, 
clear_err_if_func_written import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -229,7 +229,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. """ -@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RABI,Holomorphic}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs} +@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten} tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 @@ -260,7 +260,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) + forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI, Val(ErrIfFuncWritten)) res = forward(f, args...) tape = res[1] if ReturnPrimal @@ -290,7 +290,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) args = seed_complex_args(seen, seen2, args...) tt′ = vaTypeof(args...) 
- thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) results = thunk(f, args..., (rt(0), rt(1), rt(im))) @@ -312,7 +312,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.")) end - thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) if A <: Active args = (args..., Compiler.default_adjoint(rt)) @@ -326,10 +326,10 @@ end Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ @inline function autodiff(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} - autodiff(mode, Const(f), args...) + autodiff(EnzymeCore.set_err_if_func_written(mode), Const(f), args...) end @inline function autodiff(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} - autodiff(mode, Const(f), RT, args...) + autodiff(EnzymeCore.set_err_if_func_written(mode), Const(f), RT, args...) 
end """ @@ -393,7 +393,7 @@ f(x) = x*x (6.28,) ``` """ -@inline function autodiff(::ForwardMode{RABI}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs} +@inline function autodiff(::ForwardMode{RABI, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs, ErrIfFuncWritten} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -433,7 +433,7 @@ f(x) = x*x end thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) thunk(f, args...) end @@ -443,7 +443,7 @@ end Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ReverseMode{ReturnPrimal}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs} +@inline function autodiff_deferred(::ReverseMode{ReturnPrimal, ABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs, ABI,Holomorphic,ErrIfFuncWritten} tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 @@ -467,7 +467,7 @@ code, as well as high-order differentiation. 
ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten)) thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) if rt <: Active @@ -484,7 +484,7 @@ end Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs} +@inline function autodiff_deferred(::ForwardMode{ABI, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -531,7 +531,7 @@ code, as well as high-order differentiation. ReturnPrimal = RT <: Duplicated || RT <: BatchDuplicated ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten)) thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) thunk(f, args...) end @@ -542,10 +542,10 @@ end Like [`autodiff_deferred`](@ref) but will try to extend f to an annotation, if needed. 
""" @inline function autodiff_deferred(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} - autodiff_deferred(mode, Const(f), args...) + autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), args...) end @inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} - autodiff_deferred(mode, Const(f), RT, args...) + autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), RT, args...) end """ @@ -612,7 +612,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs} +@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs, ErrIfFuncWritten} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -640,7 +640,7 @@ result, ∂v, ∂A else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) end """ @@ -687,7 +687,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs} +@inline function autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten} width = same_or_one(1, A, args...) 
if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -706,10 +706,10 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) end -@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -735,7 +735,7 @@ end else Val(codegen_world_age(eltype(FA), primal_tt)) end - nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) if nondef[1] isa Enzyme.Compiler.PrimalErrorThunk return Nothing else @@ -853,7 +853,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(1, args...) 
@@ -877,8 +877,8 @@ result, ∂v, ∂A primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten)) RT = if A2 <: Duplicated && width != 1 if A2 isa UnionAll @@ -1038,7 +1038,7 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) @inline function gradient(rm::ReverseMode, f::F, x::X) where {F, X} if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState dx = Ref(make_zero(x)) - autodiff(rm, f∘only, Active, Duplicated(Ref(x), dx)) + autodiff(rm, f, Active, MixedDuplicated(x, dx)) return only(dx) else dx = make_zero(x) @@ -1055,7 +1055,7 @@ Like [`gradient`](@ref), except it using deferred mode. 
@inline function gradient_deferred(rm::ReverseMode, f::F, x::X) where {F, X} if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState dx = Ref(make_zero(x)) - autodiff_deferred(rm, f∘only, Active, Duplicated(Ref(x), dx)) + autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) return only(dx) else dx = make_zero(x) @@ -1250,7 +1250,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) 0.0 1.0 ``` """ -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, ReturnPrimal, RABI<:ABI} +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, ReturnPrimal, RABI<:ABI, ErrIfFuncWritten} @assert !ReturnPrimal num = ((n_out_val + chunk - 1) ÷ chunk) @@ -1268,7 +1268,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) else Val(codegen_world_age(Core.Typeof(f), tt)) end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) if num * chunk == n_out_val last_size = chunk @@ -1276,7 +1276,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) else last_size = n_out_val - (num-1)*chunk tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, 
BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) end tmp = ntuple(num) do i @@ -1299,7 +1299,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) mapreduce(LinearAlgebra.adjoint, vcat, rows) end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,ReturnPrimal,RABI<:ABI} +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,ReturnPrimal,RABI<:ABI, ErrIfFuncWritten} @assert !ReturnPrimal tt′ = Tuple{Duplicated{Core.Typeof(x)}} tt = Tuple{Core.Typeof(x)} @@ -1311,7 +1311,7 @@ end else Val(codegen_world_age(Core.Typeof(f), tt)) end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) rows = ntuple(n_outs) do i Base.@_inline_meta dx = zero(x) diff --git a/src/absint.jl b/src/absint.jl index 80cb3b9d4f..a9748d2dd5 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -2,6 +2,10 @@ # Return (bool if could interpret, julia object interpreted to) function absint(arg::LLVM.Value, partial::Bool=false) + if isa(arg, LLVM.BitCastInst) || + isa(arg, LLVM.AddrSpaceCastInst) + return absint(operands(arg)[1], partial) + end if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" @@ -146,6 +150,10 @@ function absint(arg::LLVM.Value, partial::Bool=false) end function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, 
Type},Tuple{Bool, Nothing}} + if isa(arg, LLVM.BitCastInst) || + isa(arg, LLVM.AddrSpaceCastInst) + return abs_typeof(operands(arg)[1], partial) + end if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" @@ -292,6 +300,16 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ f = LLVM.Function(LLVM.API.LLVMGetParamParent(larg)) idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == larg]) typ, byref = enzyme_extract_parm_type(f, idx, #=error=#false) + @static if VERSION < v"1.11-" + if typ !== nothing && typ <: Array && Base.isconcretetype(typ) + T = eltype(typ) + if offset === nothing || offset == 0 + return (true, Ptr{T}) + else + return (true, Int) + end + end + end if typ !== nothing && byref == GPUCompiler.BITS_REF if offset === nothing return (true, typ) @@ -321,7 +339,24 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end end end - # @show "not found", typ, offset, [fieldoffset(typ, i) for i in 1:fieldcount(typ)] + end + end + else + legal, RT = abs_typeof(larg) + if legal + if RT <: Array && Base.isconcretetype(RT) + @static if VERSION < v"1.11-" + T = eltype(RT) + + if offset == 0 + return (true, Ptr{T}) + end + + return (true, Int) + end + end + if RT <: Ptr && Base.isconcretetype(RT) + return (true, eltype(RT)) end end end diff --git a/src/compiler.jl b/src/compiler.jl index b6207b5813..477a9f6a3c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -807,9 +807,7 @@ function emit_allocobj!(B, T::DataType) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) # Obtain tag - tag = LLVM.ConstantInt(convert(UInt, Base.pointer_from_objref(T))) # do we need to root ETT - tag = LLVM.const_inttoptr(tag, T_prjlvalue_UT) - tag = LLVM.const_addrspacecast(tag, T_prjlvalue) + tag = unsafe_to_llvm(B, T) T_size_t = convert(LLVM.LLVMType, UInt) Size = LLVM.ConstantInt(T_size_t, sizeof(T)) @@ -864,6 +862,18 @@ function emit_jl!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value call!(B, 
FT, fn, [val]) end +function emit_jl_throw!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + T_void = LLVM.VoidType() + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, 12) + FT = LLVM.FunctionType(T_void, [T_prjlvalue]) + fn, _ = get_function!(mod, "jl_throw", FT) + call!(B, FT, fn, [val]) +end + function emit_box_int32!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -984,17 +994,49 @@ end AnyArray(Length::Int) = NamedTuple{ntuple(i->Symbol(i), Val(Length)),NTuple{Length,Any}} -const JuliaEnzymeNameMap = Dict{String, Any}( +struct EnzymeRuntimeException <: Base.Exception + msg::Cstring +end +function Base.showerror(io::IO, ece::EnzymeRuntimeException) + print(io, "Enzyme execution failed.\n") + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeMutabilityException <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeMutabilityException) + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeRuntimeActivityError <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +@static if VERSION >= v"1.8.0" +const JuliaEnzymeNameMap = Dict{String, Any}( "enz_val_true" => Val(true), "enz_val_false" => Val(false), - "enz_val_1" => Val(1), - "enz_any_array_1" => AnyArray(1), "enz_any_array_2" => AnyArray(2), - "enz_any_array_3" => AnyArray(3) + "enz_any_array_3" => AnyArray(3), + "enz_runtime_exc" => EnzymeRuntimeException, + "enz_mut_exc" => EnzymeMutabilityException, + "enz_runtime_activity_exc" => EnzymeRuntimeActivityError, ) +else +const JuliaEnzymeNameMap = Dict{String, Any}() +end const JuliaGlobalNameMap = Dict{String, Any}( "jl_type_type" => Type, @@ -1783,31 +1825,12 @@ end return end -struct 
EnzymeRuntimeException <: Base.Exception - msg::Cstring -end - -function Base.showerror(io::IO, ece::EnzymeRuntimeException) - print(io, "Enzyme execution failed.\n") - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end - -function throwerr(cstr::Cstring) - throw(EnzymeRuntimeException(cstr)) -end - -function emit_error(B::LLVM.IRBuilder, orig, string) +function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeException) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) # 1. get the error function - funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[LLVM.PointerType(LLVM.Int8Type())]) - ptr = @cfunction(throwerr, Union{}, (Cstring,)) - ptr = convert(UInt, ptr) - ptr = LLVM.ConstantInt(ptr) - func = inttoptr!(B, ptr, LLVM.PointerType(funcT)) if orig !== nothing bt = GPUCompiler.backtrace(orig) function printBT(io) @@ -1820,19 +1843,18 @@ function emit_error(B::LLVM.IRBuilder, orig, string) ct = if occursin("ptx", LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) GPUCompiler.emit_exception!(B, string, orig) else - call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) + err = emit_allocobj!(B, errty) + err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) + store!(B, globalstring_ptr!(B, string), err2) + emit_jl_throw!(B, addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12))) end # 2. 
Call error function and insert unreachable LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("noreturn")) - LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_error")) + if EnzymeMutabilityException != errty + LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_error")) + end return ct - # FIXME(@wsmoses): Allow for emission of new BB in this code path - # unreachable!(B) - - # 3. Change insertion point so that we don't stumble later - # after_error = BasicBlock(fn, "after_error"; ctx) - # position!(B, after_error) end function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world) @@ -2062,15 +2084,11 @@ function julia_sanitize(orig::LLVM.API.LLVMValueRef, val::LLVM.API.LLVMValueRef, position!(builder, good) ret!(builder) # ret!(builder, inp) - position!(builder, bad) - - funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[LLVM.PointerType(LLVM.Int8Type())]) - ptr = @cfunction(throwerr, Union{}, (Cstring,)) - ptr = convert(UInt, ptr) - ptr = LLVM.ConstantInt(ptr) - func = inttoptr!(builder, ptr, LLVM.PointerType(funcT)) - call!(builder, funcT, func, LLVM.Value[sval]) + err = emit_allocobj!(builder, EnzymeRuntimeException) + err2 = bitcast!(builder, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) + store!(builder, globalstring_ptr!(builder, string), err2) + emit_jl_throw!(builder, addrspacecast!(builder, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12))) unreachable!(builder) dispose(builder) @@ -2486,7 +2504,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err Base.show_backtrace(io, bt) end end - emit_error(b, nothing, msg2) + emit_error(b, nothing, msg2, EnzymeRuntimeActivityError) return C_NULL elseif errtype == API.ET_GetIndexError 
@assert B != C_NULL @@ -3331,6 +3349,8 @@ struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams expectedTapeType::Type # Whether to use the pointer ABI, default true ABI::Type{<:ABI} + # Whether to error if the function is written to + err_if_func_written::Bool end struct UnknownTapeType end @@ -6114,6 +6134,119 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; TapeType::Type = Cvoid + if params.err_if_func_written + FT = TT.parameters[1] + Ty = eltype(FT) + reg = active_reg_inner(Ty, (), world) + if reg == DupState || reg == MixedState + swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(primalf, i)))) for i in 1:length(collect(parameters(primalf)))) + todo = LLVM.Value[parameters(primalf)[1+swiftself]] + done = Set{LLVM.Value}() + doneInst = Set{LLVM.Instruction}() + while length(todo) != 0 + cur = pop!(todo) + if cur in done + continue + end + push!(done, cur) + for u in LLVM.uses(cur) + user = LLVM.user(u) + if user in doneInst + continue + end + if LLVM.API.LLVMIsAReturnInst(user) != C_NULL + continue + end + + if !mayWriteToMemory(user) + slegal , foundv = abs_typeof(user) + if slegal + reg2 = active_reg_inner(foundv, (), world) + if reg2 == ActiveState || reg2 == AnyState + continue + end + end + push!(todo, user) + continue + end + + if isa(user, LLVM.StoreInst) + # we are capturing the variable + if operands(user)[1] == cur + base = operands(user)[2] + while isa(base, LLVM.BitCastInst) || isa(base, LLVM.AddrSpaceCastInst) || isa(base, LLVM.GetElementPtrInst) + base = operands(base)[1] + end + if isa(base, LLVM.AllocaInst) + push!(doneInst, user) + push!(todo, base) + continue + end + end + # we are storing into the variable + if operands(user)[2] == cur + slegal , foundv = abs_typeof(operands(user)[1]) + if slegal + reg2 = active_reg_inner(foundv, (), world) + if reg2 == AnyState + continue + end + end + end + end + + if isa(user, LLVM.CallInst) + called = 
LLVM.called_operand(user) + if isa(called, LLVM.Function) + nm = LLVM.name(called) + if nm == "ijl_alloc_array_1d" || nm == "jl_alloc_array_1d" || + nm == "ijl_alloc_array_2d" || nm == "jl_alloc_array_2d" || + nm == "ijl_alloc_array_3d" || nm == "jl_alloc_array_3d" + continue + end + if is_readonly(called) + slegal , foundv = abs_typeof(user) + if slegal + reg2 = active_reg_inner(foundv, (), world) + if reg2 == ActiveState || reg2 == AnyState + continue + end + end + push!(todo, user) + continue + end + if !isempty(blocks(called)) && length(collect(LLVM.uses(called))) == 1 + for (parm, op) in zip(LLVM.parameters(called), operands(user)[1:end-1]) + if op == cur + push!(todo, parm) + end + end + slegal , foundv = abs_typeof(user) + if slegal + reg2 = active_reg_inner(foundv, (), world) + if reg2 == ActiveState || reg2 == AnyState + continue + end + end + push!(todo, user) + continue + end + end + end + + builder = LLVM.IRBuilder() + position!(builder, user) + resstr = "Function argument passed to autodiff cannot be proven readonly.\nIf the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)\nSee https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.\nThe potentially writing call is "*string(user)*", using "*string(cur) + slegal , foundv = absint(cur) + if slegal + resstr *= "of type "*string(foundv) + end + emit_error(builder, user, resstr, EnzymeMutabilityException) + end + end + end + end + if params.run_enzyme # Generate the adjoint memcpy_alloca_to_loadstore(mod) @@ -6988,9 +7121,9 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, 
A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} +@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} target = Compiler.EnzymeTarget() - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten) tmp_job = if World isa Nothing Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else @@ -7029,7 +7162,7 @@ end A2 end - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten) job = if World isa Nothing Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else @@ -7066,7 +7199,7 @@ end end end -@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI} +@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, 
::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten} ts_ctx = JuliaContext() ctx = @static if VERSION >= v"1.9.0-DEV.115" context(ts_ctx) @@ -7075,7 +7208,7 @@ end end activate(ctx) try - return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) finally deactivate(ctx) @static if VERSION >= v"1.9.0-DEV.115" @@ -7084,7 +7217,7 @@ end end end -@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} +@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} mi = fspec(eltype(FA), TT, World) ts_ctx = JuliaContext() ctx = @static if VERSION >= v"1.9.0-DEV.115" @@ -7094,7 +7227,7 @@ end end activate(ctx) res = try - thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) finally deactivate(ctx) @static if VERSION >= v"1.9.0-DEV.115" @@ -7110,14 +7243,14 @@ end import GPUCompiler: deferred_codegen_jobs @generated function deferred_codegen(::Val{World}, 
::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, - ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType} + ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal},::Val{ShadowInit},::Type{ExpectedTapeType}, ::Val{ErrIfFuncWritten}) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, ErrIfFuncWritten} JuliaContext() do ctx Base.@_inline_meta mi = fspec(eltype(FA), TT, World) target = EnzymeTarget() rt2 = if A isa UnionAll - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI) + params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) interp = GPUCompiler.get_interpreter(tmp_job) @@ -7141,7 +7274,7 @@ import GPUCompiler: deferred_codegen_jobs A end - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI) + params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) addr = get_trampoline(job) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index b9838e1620..5b7100f887 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -1,5 +1,5 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); 
- run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, kwargs...) + run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, ErrIfFuncWritten=false, kwargs...) tt = Tuple{map(eltype, types.parameters)...} if world === nothing @@ -15,7 +15,7 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); defaultMod = mode != API.DEM_ReverseModeCombined && mode != API.DEM_ForwardMode modifiedBetween = (defaultMod, (defaultMod for _ in types.parameters)...) end - params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI) + params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI, ErrIfFuncWritten) return Compiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world) end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 9595927aff..fb2ee4714a 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -110,6 +110,16 @@ function is_noreturn(f::LLVM.Function) end function is_readonly(f::LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(f) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + return true + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + return true + end + if intr == LLVM.Intrinsic("llvm.assume").id + return true + end for attr in collect(function_attributes(f)) if kind(attr) == kind(EnumAttribute("readonly")) return true @@ -129,6 +139,16 @@ function is_readonly(f::LLVM.Function) end function is_readnone(f::LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(f) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + return true + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + return true + end + if intr == LLVM.Intrinsic("llvm.assume").id + return true + end for attr in collect(function_attributes(cur)) if kind(attr) == kind(EnumAttribute("readnone")) return true @@ -145,6 +165,16 @@ function is_readnone(f::LLVM.Function) end function is_writeonly(f::LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(f) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + return true + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + return true + end + if intr == LLVM.Intrinsic("llvm.assume").id + return true + end for attr in collect(function_attributes(cur)) if kind(attr) == kind(EnumAttribute("readnone")) return true diff --git a/src/internal_rules.jl b/src/internal_rules.jl index e3040de747..8e2976944a 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -268,7 +268,7 @@ end 
function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() + config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) TapeType = EnzymeRules.tape_type(fwd_thunk) @@ -293,7 +293,7 @@ end function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() + config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 8f68cec991..337b54ace6 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -205,7 +205,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward = thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) 
@@ -308,7 +308,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) opt_mi = Val(world) forward, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...) annotation = annotationA @@ -444,7 +444,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act opt_mi = Val(world) _, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) tup = if annotation0 <: Active adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] @@ -702,7 +702,7 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType end opt_mi = Val(world) res = thunk(opt_mi, FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(fa, args...) + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false))(fa, args...) 
return if annotation <: Const ReturnType(allFirst(Val(width+1), res)) else @@ -837,7 +837,7 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} opt_mi = Val(world) forward, adjoint = thunk(opt_mi, FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) forward(fa, args...) else nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) @@ -984,7 +984,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween opt_mi = Val(world) forward, adjoint = thunk(opt_mi, FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) args2 = if tape.shadow_return !== nothing if width == 1 diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 651c05952e..1db4cd8d0b 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -3,7 +3,7 @@ function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ss FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) ft = ghos ? 
Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -18,7 +18,7 @@ function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -194,7 +194,7 @@ end if mode == API.DEM_ForwardMode if fwdmodenm === nothing etarget = Compiler.EnzymeTarget() - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI) + eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false) ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) cmod, fwdmodenm, _, _ = _thunk(ejob, #=postopt=#false) @@ -225,7 +225,7 @@ end if augfwdnm === nothing || adjointnm === nothing etarget = Compiler.EnzymeTarget() # TODO modifiedBetween - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI) + eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false) ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, #=postopt=#false) diff --git a/src/utils.jl b/src/utils.jl index fdc15db247..cc2af40c74 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. Only use inside Enzyme.jl should be for Types. """ -@inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) +@inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(Base.identity, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) export unsafe_to_pointer @inline is_concrete_tuple(x::Type{T2}) where T2 = (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) @@ -66,8 +66,7 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) if legal curent_bb = position(B) fn = LLVM.parent(curent_bb) - world = Compiler.enzyme_extract_world(fn) - if Compiler.guaranteed_const_nongen(jTy, world) + if Compiler.guaranteed_const_nongen(jTy, nothing) API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) end end @@ -88,8 +87,7 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) if legal curent_bb = position(B) fn = LLVM.parent(curent_bb) - world = Compiler.enzyme_extract_world(fn) - if Compiler.guaranteed_const_nongen(jTy, world) + if 
Compiler.guaranteed_const_nongen(jTy, nothing) API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) end end diff --git a/test/abi.jl b/test/abi.jl index 93bf471fde..fa1cd723cb 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -9,11 +9,11 @@ using Test res = autodiff(Reverse, f, Const, Const(nothing)) @test res === ((nothing,),) - res = autodiff(ReverseMode{false,NonGenABI, false}(), f, Const, Const(nothing)) + res = autodiff(ReverseMode{false,NonGenABI, false, false}(), f, Const, Const(nothing)) @test res === ((nothing,),) @test () === autodiff(Forward, f, Const, Const(nothing)) - @test () === autodiff(ForwardMode{NonGenABI}(), f, Const, Const(nothing)) + @test () === autodiff(ForwardMode{NonGenABI, false}(), f, Const, Const(nothing)) res = autodiff(Reverse, f, Const(nothing)) @test res === ((nothing,),) @@ -22,11 +22,11 @@ using Test res = autodiff_deferred(Reverse, f, Const(nothing)) @test res === ((nothing,),) - res = autodiff_deferred(ReverseMode{false,NonGenABI, false}(), f, Const, Const(nothing)) + res = autodiff_deferred(ReverseMode{false,NonGenABI, false, false}(), f, Const, Const(nothing)) @test res === ((nothing,),) @test () === autodiff_deferred(Forward, f, Const(nothing)) - @test () === autodiff_deferred(ForwardMode{NonGenABI}(), f, Const, Const(nothing)) + @test () === autodiff_deferred(ForwardMode{NonGenABI, false}(), f, Const, Const(nothing)) # ConstType -> Type{Int} res = autodiff(Reverse, f, Const, Const(Int)) @@ -65,7 +65,7 @@ using Test _, res0 = autodiff(Reverse, unused, Active, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 - _, res0 = autodiff(ReverseMode{false, NonGenABI, false}(), unused, Active, Const(nothing), Active(2.0))[1] + _, res0 = autodiff(ReverseMode{false, NonGenABI, false, false}(), unused, Active, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) @@ -73,7 +73,7 @@ using Test res0, = autodiff(Forward, unused, DuplicatedNoNeed, 
Const(nothing), DuplicatedNoNeed(2.0, 1.0)) @test res0 ≈ 1.0 - res0, = autodiff(ForwardMode{NonGenABI}(), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) + res0, = autodiff(ForwardMode{NonGenABI, false}(), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) @test res0 ≈ 1.0 _, res0 = autodiff(Reverse, unused, Const(nothing), Active(2.0))[1] @@ -409,8 +409,40 @@ end @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(ABar()), Duplicated(3.0, 1.0))[1] ≈ 2.0 @test Enzyme.autodiff(Forward, ABar(), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 + + struct RWClos + x::Vector{Float64} + end + + function (c::RWClos)(y) + c.x[1] *= y + return y + end + + c = RWClos([4.]) + + @test_throws Enzyme.Compiler.EnzymeMutabilityException autodiff(Reverse, c, Active(3.0)) + + @test autodiff(Reverse, Const(c), Active(3.0))[1][1] ≈ 1.0 + @test autodiff(Reverse, Duplicated(c, RWClos([2.7])), Active(3.0))[1][1] ≈ (1.0 + 2.7 * 4 * 3) + + struct RWClos2 + x::Vector{Float64} + end + + function (c::RWClos2)(y) + return y + c.x[1] + end + + c2 = RWClos2([4.]) + + @test autodiff(Reverse, c2, Active(3.0))[1][1] ≈ 1.0 + @test autodiff(Reverse, Const(c2), Active(3.0))[1][1] ≈ 1.0 + @test autodiff(Reverse, Duplicated(c2, RWClos2([2.7])), Active(3.0))[1][1] ≈ 1.0 end + + @testset "Promotion" begin x = [1.0, 2.0]; dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]; rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2 diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 09d120e246..9c9b097422 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -38,8 +38,6 @@ function grad_exp_kernel(A, dA) return nothing end -Enzyme.API.printall!(true) - @testset "exp_kernel" begin A = AMDGPU.ones(64,) @roc groupsize=length(A) exp_kernel(A) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 9a5c0bdbb2..835f195c4b 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -238,30 +238,31 @@ function tcholsolv_upper(A, B, i) return c[i] end + @testset 
"Cholesky PotRS 3x5" begin x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] for i in 1:15 B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3] - reverse_grad = Enzyme.gradient(Reverse, B -> tcholsolv_lower(x, B, i), B) + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B) # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)), size(B)) finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1] @test reverse_grad ≈ finite_diff # @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, B -> tcholsolv_upper(x, B, i), B) + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B) # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B)), size(B)) finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1] @test reverse_grad ≈ finite_diff # @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, x -> tcholsolv_lower(x, B, i), x) + reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x) #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)), size(x)) finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1] @test reverse_grad ≈ finite_diff #@test forward_grad ≈ finite_diff # - reverse_grad = Enzyme.gradient(Reverse, x -> tcholsolv_upper(x, B, i), x) + reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x) #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)), size(x)) finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1] @test reverse_grad ≈ finite_diff diff --git a/test/rrules.jl b/test/rrules.jl index b6681e6739..3d330cf5fc 100644 --- 
a/test/rrules.jl +++ b/test/rrules.jl @@ -340,7 +340,7 @@ end @testset "Closure rule" begin cl = Closure([3.14]) - res = autodiff(Reverse, cl, Active, Active(2.7))[1][1] + res = autodiff(Reverse, Const(cl), Active, Active(2.7))[1][1] @test res ≈ 7 * 2.7 + 3.14 * 1000 @test cl.v[1] ≈ 0.0 end diff --git a/test/runtests.jl b/test/runtests.jl index 092cf8abce..46a5a7c484 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -144,10 +144,10 @@ end @test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState @test Enzyme.Compiler.active_reg_inner(Tuple{A,A} where A, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState world = codegen_world_age(typeof(f0), Tuple{Float64}) - thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI) - thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI) - thunk_c = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI) - thunk_d = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI) + thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) + thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), 
Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) + thunk_c = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) + thunk_d = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) @test thunk_a.adjoint !== thunk_b.adjoint @test thunk_c.adjoint === thunk_a.adjoint @test thunk_c.adjoint === thunk_d.adjoint @@ -156,7 +156,7 @@ end @test thunk_a(Const(f0), Active(2.0), 2.0) == ((2.0,),) @test thunk_b(Const(f0), Const(2.0)) === ((nothing,),) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) @test forward(Const(f0), Active(2.0)) == (nothing,nothing,nothing) @test pullback(Const(f0), Active(2.0), 1.0, nothing) == ((1.0,),) @@ -167,7 +167,7 @@ end d = Duplicated([3.0, 5.0], [0.0, 0.0]) world = codegen_world_age(typeof(mul2), Tuple{Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false)) res = forward(Const(mul2), d) @test typeof(res[1]) == 
Tuple{Float64, Float64} pullback(Const(mul2), d, 1.0, res[1]) @@ -176,7 +176,7 @@ end d = Duplicated([3.0, 5.0], [0.0, 0.0]) world = codegen_world_age(typeof(vrec), Tuple{Int, Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false)) res = forward(Const(vrec), Const(Int(1)), d) pullback(Const(vrec), Const(1), d, 1.0, res[1]) @test d.dval[1] ≈ 5.0 @@ -261,11 +261,9 @@ sqrtsumsq2(x) = (sum(abs2, x)*sum(abs2,x)) # TODO we need to fix julia to remove unused bounds checks # @test !occursin("aug",fn) - Enzyme.API.printall!(true) fn = sprint() do io Enzyme.Compiler.enzyme_code_llvm(io, sqrtsumsq2, Active, Tuple{Duplicated{Vector{Float64}}}; dump_module=true) end - Enzyme.API.printall!(false) @test occursin("diffe",fn) if count("call fastcc void @diffejulia__mapreduce", fn) != 1 println(sprint() do io @@ -937,7 +935,7 @@ function grad_closure(f, x) dy = zeros(n) dy[1] = 1.0 - autodiff(Reverse, noretval, Duplicated(x,dx), Duplicated(y, dy)) + autodiff(Reverse, Const(noretval), Duplicated(x,dx), Duplicated(y, dy)) return dx end @@ -1060,7 +1058,7 @@ end @test res.y == nothing end -@testset "Methoe errors" begin +@testset "Method errors" begin fwd = Enzyme.autodiff_thunk(Forward, Const{typeof(sum)}, Duplicated, Duplicated{Vector{Float64}}) @test_throws MethodError fwd(ones(10)) @test_throws MethodError fwd(Duplicated(ones(10), ones(10))) @@ -1167,7 +1165,7 @@ end # doesn't use any of the const data values, but now that we error for activity confusion, we need to # mark runtimeActivity to let this pass 
Enzyme.API.runtimeActivity!(true) - Enzyme.autodiff(Enzyme.Reverse, smallrf, Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) + Enzyme.autodiff(Enzyme.Reverse, Const(smallrf), Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) @test dweights[1] ≈ 1. function invokesum(weights::Vector{Float64}, data::Vector{Float64})::Float64 @@ -2122,8 +2120,8 @@ end -t nothing end - autodiff(Reverse, tobedifferentiated, Duplicated(F, L), Const(false)) - autodiff(Forward, tobedifferentiated, Duplicated(F, L), Const(false)) + autodiff(Reverse, Const(tobedifferentiated), Duplicated(F, L), Const(false)) + autodiff(Forward, Const(tobedifferentiated), Duplicated(F, L), Const(false)) end main() @@ -2855,9 +2853,9 @@ end J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, Val(5)) J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, Val(5)) - J_f_1(A, x) = Enzyme.jacobian(Forward, θ -> f_test_1(A, θ), x) - J_f_2(A, x) = Enzyme.jacobian(Forward, θ -> f_test_2(A, θ), x) - J_f_3(u, A, x) = Enzyme.jacobian(Forward, θ -> f_test_3!(u, A, θ), x) + J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x) + J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x) + J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x) x = ones(6) A = Matrix{Float64}(LinearAlgebra.I, 5, 5) @@ -3030,7 +3028,7 @@ end c = ones(3) inner(e) = c .+ e - fres = Enzyme.autodiff(Enzyme.Forward, inner, Duplicated{Vector{Float64}}, Duplicated([0., 0., 0.], [1., 1., 1.]))[1] + fres = Enzyme.autodiff(Enzyme.Forward, Const(inner), Duplicated{Vector{Float64}}, Duplicated([0., 0., 0.], [1., 1., 1.]))[1] @test c ≈ [1.0, 1.0, 1.0] @test fres ≈ [1.0, 1.0, 1.0] end @@ -3133,7 +3131,7 @@ end end Enzyme.API.runtimeActivity!(true) - res = autodiff(Forward, f2, Duplicated, Duplicated(0.2, 1.0)) + res = autodiff(Forward, Const(f2), Duplicated, Duplicated(0.2, 1.0)) Enzyme.API.runtimeActivity!(false) @test res[1] ≈ 0.2 # broken as the return of an 
apply generic is {primal, primal} From ece399d6802616b36c046e4a7ecb9039a417079b Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Aug 2024 09:24:23 -0700 Subject: [PATCH 211/495] Add bfloat16 (#1708) * Add bfloat16 * Add ext --- Project.toml | 4 ++++ ext/EnzymeBFloat16sExt.jl | 10 ++++++++++ src/api.jl | 8 +++++++- src/typetree.jl | 6 ++++++ 4 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 ext/EnzymeBFloat16sExt.jl diff --git a/Project.toml b/Project.toml index 9cdd1ebe9c..67f42588ae 100644 --- a/Project.toml +++ b/Project.toml @@ -17,18 +17,21 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] +EnzymeBFloat16sExt = "BFloat16s" EnzymeChainRulesCoreExt = "ChainRulesCore" EnzymeLogExpFunctionsExt = "LogExpFunctions" EnzymeSpecialFunctionsExt = "SpecialFunctions" EnzymeStaticArraysExt = "StaticArrays" [compat] +BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" @@ -43,6 +46,7 @@ StaticArrays = "1" julia = "1.6" [extras] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/ext/EnzymeBFloat16sExt.jl b/ext/EnzymeBFloat16sExt.jl new file mode 100644 index 0000000000..35766c28e6 --- /dev/null +++ b/ext/EnzymeBFloat16sExt.jl @@ -0,0 +1,10 @@ +module EnzymeBFloat16sExt + +using BFloat16s +using Enzyme + +function Enzyme.Compiler.typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) + return TypeTree(Enzyme.API.DT_BFloat16, -1, ctx) +end + +end diff --git 
a/src/api.jl b/src/api.jl index d5cb2a5451..6de95beb55 100644 --- a/src/api.jl +++ b/src/api.jl @@ -29,7 +29,9 @@ IntList() = IntList(Ptr{Int64}(0),0) DT_Half = 3, DT_Float = 4, DT_Double = 5, - DT_Unknown = 6 + DT_Unknown = 6, + DT_FP80 = 7, + DT_BFloat16 = 8 ) function EnzymeConcreteTypeIsFloat(cc::CConcreteType) @@ -39,6 +41,10 @@ function EnzymeConcreteTypeIsFloat(cc::CConcreteType) return LLVM.FloatType() elseif cc == DT_Double return LLVM.DoubleType() + elseif cc == DT_FP80 + return LLVM.X86FP80Type() + elseif cc == DT_BFloat16 + return LLVM.BFloatType() else return nothing end diff --git a/src/typetree.jl b/src/typetree.jl index f8c70808be..065dccbbd8 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -111,6 +111,12 @@ function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Double, -1, ctx) end +@static if VERSION >= v"1.11-" +function typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::TypeTreeTable) + return TypeTree(API.DT_BFloat16, -1, ctx) +end +end + function typetree_inner(::Type{BigFloat}, ctx, dl, seen::TypeTreeTable) return TypeTree() end From d19307d6e55f38b072bd309ee4a5174045f23bad Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Aug 2024 12:04:27 -0700 Subject: [PATCH 212/495] Add ABI setter (#1709) --- lib/EnzymeCore/src/EnzymeCore.jl | 4 ++++ src/Enzyme.jl | 4 ++-- test/abi.jl | 12 ++++++------ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 94df9a61c5..ef1b34e56d 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -228,6 +228,8 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true, false}() @inline set_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,true}() @inline 
clear_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,false}() +@inline set_abi(::ReverseMode{ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten}, ::Type{NewABI}) where {ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,NewABI,Holomorphic,ErrIfFuncWritten}() + """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI} @@ -255,6 +257,8 @@ const Forward = ForwardMode{DefaultABI, false}() @inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,true}() @inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,false}() +@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten}() + function autodiff end function autodiff_deferred end function autodiff_thunk end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 076eb62761..5ac12d1fe0 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,8 +5,8 @@ import EnzymeCore import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, 
FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written +import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi +export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc diff --git a/test/abi.jl b/test/abi.jl index fa1cd723cb..e07b7403ce 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -9,11 +9,11 @@ using Test res = autodiff(Reverse, f, Const, Const(nothing)) @test res === ((nothing,),) - res = autodiff(ReverseMode{false,NonGenABI, false, false}(), f, Const, Const(nothing)) + res = autodiff(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing)) @test res === ((nothing,),) @test () === autodiff(Forward, f, Const, Const(nothing)) - @test () === autodiff(ForwardMode{NonGenABI, false}(), f, Const, Const(nothing)) + @test () === autodiff(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing)) res = autodiff(Reverse, f, Const(nothing)) @test res === ((nothing,),) @@ -22,11 +22,11 @@ using Test res = autodiff_deferred(Reverse, f, Const(nothing)) @test res === ((nothing,),) - res = autodiff_deferred(ReverseMode{false,NonGenABI, false, false}(), f, Const, Const(nothing)) + res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing)) @test res === ((nothing,),) @test () === autodiff_deferred(Forward, f, Const(nothing)) - @test () === autodiff_deferred(ForwardMode{NonGenABI, false}(), f, Const, Const(nothing)) + @test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing)) # ConstType -> Type{Int} res = autodiff(Reverse, f, Const, Const(Int)) @@ -65,7 +65,7 @@ using Test _, res0 = autodiff(Reverse, unused, 
Active, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 - _, res0 = autodiff(ReverseMode{false, NonGenABI, false, false}(), unused, Active, Const(nothing), Active(2.0))[1] + _, res0 = autodiff(Enzyme.set_abi(Reverse, NonGenABI), unused, Active, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) @@ -73,7 +73,7 @@ using Test res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), DuplicatedNoNeed(2.0, 1.0)) @test res0 ≈ 1.0 - res0, = autodiff(ForwardMode{NonGenABI, false}(), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) + res0, = autodiff(Enzyme.set_abi(Forward, NonGenABI), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) @test res0 ≈ 1.0 _, res0 = autodiff(Reverse, unused, Const(nothing), Active(2.0))[1] From 356ef3493085cdfa5084c4c6d23859e8f183ed3e Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Aug 2024 12:04:57 -0700 Subject: [PATCH 213/495] More diverse error types (#1710) * More diverse types * fix --- src/compiler.jl | 55 ++++++++++++++++++++++++++++++++++++++++-------- test/runtests.jl | 2 +- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 477a9f6a3c..c0f00a0bf1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1022,6 +1022,35 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) print(io, msg, '\n') end +struct EnzymeNoTypeError <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeNoTypeError) + print(io, "Enzyme cannot deduce type\n") + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeNoShadowError <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeNoShadowError) + print(io, "Enzyme could not find shadow for value\n") + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeNoDerivativeError <: Base.Exception + msg::Cstring +end 
+ +function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + @static if VERSION >= v"1.8.0" const JuliaEnzymeNameMap = Dict{String, Any}( "enz_val_true" => Val(true), @@ -1033,6 +1062,9 @@ const JuliaEnzymeNameMap = Dict{String, Any}( "enz_runtime_exc" => EnzymeRuntimeException, "enz_mut_exc" => EnzymeMutabilityException, "enz_runtime_activity_exc" => EnzymeRuntimeActivityError, + "enz_no_type_exc" => EnzymeNoTypeError, + "enz_no_shadow_exc" => EnzymeNoShadowError, + "enz_no_derivative_exc" => EnzymeNoDerivativeError, ) else const JuliaEnzymeNameMap = Dict{String, Any}() @@ -2139,21 +2171,27 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward pass", msg) || occursin("No reverse pass found", msg) ir = nothing end - exc = NoDerivativeException(msg, ir, bt) if B != C_NULL B = IRBuilder(B) - msg2 = sprint() do io - Base.showerror(io, exc) + msg2 = sprint() do io + if ir !== nothing + print(io, "Current scope: \n") + print(io, ir) + end + print(io, '\n', msg, '\n') + if bt !== nothing + Base.show_backtrace(io, bt) + println(io) + end end - emit_error(B, nothing, msg2) + emit_error(B, nothing, msg2, EnzymeNoDerivativeError) return C_NULL end - throw(exc) + throw(NoDerivativeException(msg, ir, bt)) elseif errtype == API.ET_NoShadow gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) msgN = sprint() do io::IO - print(io, "Enzyme could not find shadow for value\n") if isa(val, LLVM.Argument) fn = parent_scope(val) ir = string(LLVM.name(fn))*string(function_type(fn)) @@ -2174,7 +2212,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err println(io) end end - emit_error(IRBuilder(B), nothing, msgN) + emit_error(IRBuilder(B), nothing, msgN, EnzymeNoShadowError) return 
LLVM.null(get_shadow_type(gutils, value_type(val))).ref elseif errtype == API.ET_IllegalTypeAnalysis data = API.EnzymeTypeAnalyzerRef(data) @@ -2199,7 +2237,6 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err API.EnzymeStringFree(ip) msg2 = sprint() do io::IO - print(io, "Enzyme cannot deduce type\n") if !occursin("Cannot deduce single type of store", msg) if ir !== nothing print(io, "Current scope: \n") @@ -2220,7 +2257,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err println(io, "within ", mi) end end - emit_error(B, nothing, msg2) + emit_error(B, nothing, msg2, EnzymeNoTypeError) return C_NULL elseif errtype == API.ET_IllegalFirstPointer throw(IllegalFirstPointerException(msg, ir, bt)) diff --git a/test/runtests.jl b/test/runtests.jl index 46a5a7c484..8fdd6f6037 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2494,7 +2494,7 @@ end @testset "Exception" begin f_no_derv(x) = ccall("extern doesnotexist", llvmcall, Float64, (Float64,), x) - @test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, f_no_derv, Active, Active(0.5)) + @test_throws Enzyme.Compiler.EnzymeNoDerivativeError autodiff(Reverse, f_no_derv, Active, Active(0.5)) f_union(cond, x) = cond ? 
x : 0 g_union(cond, x) = f_union(cond,x)*x From 2b681d54234190fdcf70b2844af8cd33700ee0be Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 8 Aug 2024 12:44:09 -0700 Subject: [PATCH 214/495] Fixup more than simple jacobian (#1712) * Fixup more than simple jacobian * remove prints * work around stack * ease jacobian * fix * fix * fix * fix * fix * fix --- src/Enzyme.jl | 277 +++++++++++++++++++++++++++++++++++++++-------- src/compiler.jl | 5 + test/runtests.jl | 125 +++++++++++++++++++++ 3 files changed, 362 insertions(+), 45 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 5ac12d1fe0..5bdecbcceb 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1002,6 +1002,10 @@ end end end +@inline function onehot(x::AbstractFloat) + return (one(x),) +end + """ gradient(::ReverseMode, f, x) @@ -1126,10 +1130,15 @@ grad = gradient(Forward, f, [2.0, 3.0]) ``` """ @inline function gradient(::ForwardMode, f, x; shadow=onehot(x)) - if length(x) == 0 + if length(shadow) == 0 return () end - values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + res = values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + if x isa AbstractFloat + res[1] + else + res + end end @inline function chunkedonehot(x, ::Val{chunk}) where chunk @@ -1141,6 +1150,10 @@ end end end +@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk + return ((one(x),),) +end + @inline tupleconcat(x) = x @inline tupleconcat(x, y) = (x..., y...) @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) @@ -1171,44 +1184,84 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) tmp = ntuple(length(shadow)) do i values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end - tupleconcat(tmp...) + res = tupleconcat(tmp...) 
+ if x isa AbstractFloat + res[1] + else + res + end end @inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} - ntuple(length(shadow)) do i + res = ntuple(length(shadow)) do i autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end + if x isa AbstractFloat + res[1] + else + res + end end """ jacobian(::ForwardMode, f, x; shadow=onehot(x)) jacobian(::ForwardMode, f, x, ::Val{chunk}; shadow=onehot(x)) -Compute the jacobian of an array-input function `f` using (potentially vector) -forward mode. This is a simple rename of the [`gradient`](@ref) function, -and all relevant arguments apply here. +Compute the jacobian of an array or scalar-input function `f` using (potentially vector) +forward mode. All relevant arguments of the forward-mode [`gradient`](@ref) function +apply here. Example: ```jldoctest -f(x) = [x[1]*x[2], x[2]] +f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = jacobian(Forward, f, [2.0, 3.0]) +grad = jacobian(Forward, f, [2.0, 3.0, 4.0]) # output -2×2 Matrix{Float64}: - 3.0 2.0 - 0.0 1.0 +2×3 Matrix{Float64}: + 3.0 2.0 0.0 + 0.0 1.0 1.0 ``` + +For functions which return an AbstractArray, this function will return an array +whose shape is `(size(output)..., size(input)...)` + +For functions who return other types, this function will retun an array or tuple +of shape `size(input)` of values of the output type. 
""" @inline function jacobian(::ForwardMode, f, x; shadow=onehot(x)) - cols = if length(x) == 0 - return () + cols = if length(shadow) == 0 + () else values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) end - reduce(hcat, cols) + if x isa AbstractFloat + cols[1] + elseif length(cols) > 0 && cols[1] isa AbstractArray + inshape = size(x) + outshape = size(cols[1]) + # st : outshape x total inputs + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st + else + reshape(st, (outshape..., inshape...)) + end + + st3 + elseif x isa AbstractArray + inshape = size(x) + reshape(collect(cols), inshape) + else + cols + end end @inline function jacobian(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} @@ -1216,50 +1269,109 @@ end throw(ErrorException("Cannot differentiate with a batch size of 0")) end tmp = ntuple(length(shadow)) do i + Base.@_inline_meta values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end cols = tupleconcat(tmp...) 
- reduce(hcat, cols) + if x isa AbstractFloat + cols[1] + elseif length(cols) > 0 && cols[1] isa AbstractArray + inshape = size(x) + outshape = size(cols[1]) + # st : outshape x total inputs + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st + else + reshape(st, (outshape..., inshape...)) + end + + st3 + elseif x isa AbstractArray + inshape = size(x) + reshape(collect(cols), inshape) + else + cols + end end @inline function jacobian(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} cols = ntuple(length(shadow)) do i + Base.@_inline_meta autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end - reduce(hcat, cols) + if x isa AbstractFloat + cols[1] + elseif length(cols) > 0 && cols[1] isa AbstractArray + inshape = size(x) + outshape = size(cols[1]) + # st : outshape x total inputs + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st + else + reshape(st, (outshape..., inshape...)) + end + + st3 + elseif x isa AbstractArray + inshape = size(x) + reshape(collect(cols), inshape) + else + cols + end end """ - jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}) + jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}=Val(1)) + jacobian(::ReverseMode, f, x) -Compute the jacobian of an array-input function `f` using (potentially vector) +Compute the jacobian of an array-output function `f` using (potentially vector) reverse mode. The `chunk` argument denotes the chunk size to use and `num_outs` denotes the number of outputs `f` will return in an array. 
Example: ```jldoctest -f(x) = [x[1]*x[2], x[2]] +f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) +grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], Val(2)) # output -2×2 Matrix{Float64}: - 3.0 2.0 - 0.0 1.0 +2×3 transpose(::Matrix{Float64}) with eltype Float64: + 3.0 2.0 0.0 + 0.0 1.0 1.0 +``` + +For functions which return an AbstractArray, this function will return an array +whose shape is `(size(output)..., size(input)...)` + +For functions who return other types, this function will retun an array or tuple +of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, ReturnPrimal, RABI<:ABI, ErrIfFuncWritten} - @assert !ReturnPrimal +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} num = ((n_out_val + chunk - 1) ÷ chunk) if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tt′ = Tuple{BatchDuplicated{Core.Typeof(x), chunk}} - tt = Tuple{Core.Typeof(x)} + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + tt′ = MD ? Tuple{BatchMixedDuplicated{XT, chunk}} : Tuple{BatchDuplicated{XT, chunk}} + tt = Tuple{XT} rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} @@ -1281,28 +1393,59 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) tmp = ntuple(num) do i Base.@_inline_meta - dx = ntuple(i == num ? last_size : chunk) do idx + dx = ntuple(Val(i == num ? last_size : chunk)) do idx Base.@_inline_meta - zero(x) + z = make_zero(x) + MD ? Ref(z) : z end - res = (i == num ? primal2 : primal)(Const(f), BatchDuplicated(x, dx)) + res = (i == num ? 
primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) tape = res[1] j = 0 for shadow in res[3] j += 1 @inbounds shadow[(i-1)*chunk+j] += Compiler.default_adjoint(eltype(typeof(shadow))) end - (i == num ? adjoint2 : adjoint)(Const(f), BatchDuplicated(x, dx), tape) - return dx + (i == num ? adjoint2 : adjoint)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) + return MD ? (ntuple(Val(i == num ? last_size : chunk)) do idx + Base.@_inline_meta + dx[idx][] + end) : dx, (i == 1 ? size(res[3][1]) : nothing) + end + rows = tupleconcat(map(first, tmp)...) + outshape = tmp[1][2] + if x isa AbstractArray + inshape = size(x) + + st = @static if VERSION >= v"1.9" + Base.stack(rows) + else + reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) + end + + st2 = if length(outshape) == 1 || VERSION < v"1.9" + st + else + reshape(st, (inshape..., outshape...)) + end + + st3 = if length(outshape) == 1 && length(inshape) == 1 + transpose(st2) + else + transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... ) + PermutedDimsArray(st2, transp) + end + + st3 + else + reshape(collect(rows), outshape) end - rows = tupleconcat(tmp...) - mapreduce(LinearAlgebra.adjoint, vcat, rows) end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,ReturnPrimal,RABI<:ABI, ErrIfFuncWritten} - @assert !ReturnPrimal - tt′ = Tuple{Duplicated{Core.Typeof(x)}} - tt = Tuple{Core.Typeof(x)} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} + tt = Tuple{XT} rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} @@ -1312,16 +1455,60 @@ end Val(codegen_world_age(Core.Typeof(f), tt)) end primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) - rows = ntuple(n_outs) do i + tmp = ntuple(n_outs) do i Base.@_inline_meta - dx = zero(x) - res = primal(Const(f), Duplicated(x, dx)) + z = make_zero(x) + dx = MD ? Ref(z) : z + res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx)) tape = res[1] @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) - adjoint(Const(f), Duplicated(x, dx), tape) - return dx + adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) + return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) + end + rows = map(first, tmp) + outshape = tmp[1][2] + if x isa AbstractArray + inshape = size(x) + st = @static if VERSION >= v"1.9" + Base.stack(rows) + else + reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) + end + + st2 = if length(outshape) == 1 || VERSION < v"1.9" + st + else + reshape(st, (inshape..., outshape...)) + end + + st3 = if length(outshape) == 1 && length(inshape) == 1 + transpose(st2) + else + transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... 
) + PermutedDimsArray(st2, transp) + end + + st3 + else + reshape(collect(rows), outshape) + end +end + +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} + res = f(x) + jac = if res isa AbstractArray + jacobian(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x, Val(length(jac))) + elseif res isa AbstractFloat + gradient(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x) + else + throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) + end + + if ReturnPrimal + (res, jac) + else + jac end - mapreduce(LinearAlgebra.adjoint, vcat, rows) end """ diff --git a/src/compiler.jl b/src/compiler.jl index c0f00a0bf1..e1ca4c395a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3886,7 +3886,12 @@ include("rules/activityrules.jl") @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: DuplicatedNoNeed = API.DFT_DUP_NONEED @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED +const DumpPreEnzyme = Ref(false) + function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) + if DumpPreEnzyme[] + API.EnzymeDumpModuleRef(mod.ref) + end world = job.world interp = GPUCompiler.get_interpreter(job) rt = job.config.params.rt diff --git a/test/runtests.jl b/test/runtests.jl index 8fdd6f6037..94015cfa4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2803,6 +2803,131 @@ end end end +@testset "Simple Jacobian" begin + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0]) ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(1)) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(1)) ≈ 
[1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(1)) ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(2)) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(2)) ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(1)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(2)) ≈ [1.0, 2.0] + + x = float.(reshape(1:6, 2, 3)) + + fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(1)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(2)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, Val(4), Val(1)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Reverse, 
fillabs2, x, Val(4), Val(2)) + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + struct InpStruct + i1::Float64 + i2::Float64 + i3::Float64 + end + + fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] + + x2 = InpStruct(1.0, 2.0, 3.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(1)) + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(2)) + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + struct OutStruct + i1::Float64 + i2::Float64 + i3::Float64 + end + + filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x) + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(1)) + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = 
Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(2)) + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + +end + + @testset "Jacobian" begin function inout(v) [v[2], v[1]*v[1], v[1]*v[1]*v[1]] From 91aaac9a413d9b00104e246004878aa3852267e3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 02:24:32 -0700 Subject: [PATCH 215/495] fix: typetree_inner is not in Compiler module (#1716) --- ext/EnzymeBFloat16sExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/EnzymeBFloat16sExt.jl b/ext/EnzymeBFloat16sExt.jl index 35766c28e6..0fda13617e 100644 --- a/ext/EnzymeBFloat16sExt.jl +++ b/ext/EnzymeBFloat16sExt.jl @@ -3,7 +3,7 @@ module EnzymeBFloat16sExt using BFloat16s using Enzyme -function Enzyme.Compiler.typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) +function Enzyme.typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) return TypeTree(Enzyme.API.DT_BFloat16, -1, ctx) end From 700eaa38bec6e78c9c26f7e6c4c1054b4130a87c Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 9 Aug 2024 13:45:15 +0200 Subject: [PATCH 216/495] Fix doctests, remove Formatter CI, add caching to Documenter CI (#1721) --- .JuliaFormatter.toml | 1 - .github/workflows/CI.yml | 1 + .github/workflows/Format.yml | 40 ---------------------------- .github/workflows/scripts_deploy.yml | 1 + docs/src/index.md | 11 ++++---- src/Enzyme.jl | 7 +---- 6 files changed, 9 insertions(+), 52 deletions(-) delete mode 100644 .JuliaFormatter.toml delete mode 100644 .github/workflows/Format.yml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 857c3ae3e5..0000000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1 +0,0 @@ 
-style = "yas" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 20415db568..5093cf3a5a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -259,6 +259,7 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: '1' + - uses: julia-actions/cache@v1 - run: | julia --project=docs -e ' using Pkg diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml deleted file mode 100644 index 88098a453e..0000000000 --- a/.github/workflows/Format.yml +++ /dev/null @@ -1,40 +0,0 @@ -on: - push: - branches: - - master - tags: '*' - pull_request: - types: - - opened - - reopened - - synchronize - - ready_for_review - -jobs: - format: - runs-on: ubuntu-20.04 - timeout-minutes: 30 - steps: - - uses: actions/checkout@v4.1.5 - - - uses: dorny/paths-filter@v3.0.2 - id: filter - with: - filters: | - julia_file_change: - - added|modified: '**.jl' - - - uses: julia-actions/setup-julia@latest - if: steps.filter.outputs.julia_file_change == 'true' - with: - version: 1.9 - - - name: Apply JuliaFormatter - if: steps.filter.outputs.julia_file_change == 'true' - run: | - julia --color=yes dev/flux_format.jl --verbose . 
- - - name: Check formatting diff - if: steps.filter.outputs.julia_file_change == 'true' - run: | - git diff --color=always --exit-code diff --git a/.github/workflows/scripts_deploy.yml b/.github/workflows/scripts_deploy.yml index 961a0bd3a6..bcfad08f6b 100644 --- a/.github/workflows/scripts_deploy.yml +++ b/.github/workflows/scripts_deploy.yml @@ -16,6 +16,7 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: '1' + - uses: julia-actions/cache@v1 - run: | julia --project=docs -e ' using Pkg diff --git a/docs/src/index.md b/docs/src/index.md index 7ea84296ad..1f7f092a99 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -184,13 +184,13 @@ julia> foo(x) = [rosenbrock_inp(x), prod(x)]; julia> output_size = Val(2) # here we have to provide the output size of `foo` since it cannot be statically inferred jacobian(Reverse, foo, [1.0, 2.0], output_size) -2×2 Matrix{Float64}: +2×2 transpose(::Matrix{Float64}) with eltype Float64: -400.0 200.0 2.0 1.0 julia> chunk_size = Val(2) # By specifying the optional chunk size argument, we can use vector inverse mode to propogate derivatives of multiple outputs at once. 
jacobian(Reverse, foo, [1.0, 2.0], output_size, chunk_size) -2×2 Matrix{Float64}: +2×2 transpose(::Matrix{Float64}) with eltype Float64: -400.0 200.0 2.0 1.0 @@ -217,7 +217,7 @@ julia> f(x) = sin(x[1] * x[2]); julia> hvp(f, [2.0, 3.0], [5.0, 2.7]) 2-element Vector{Float64}: - 19.6926882637302 + 19.69268826373025 16.201003759768003 ``` @@ -225,6 +225,7 @@ Enzyme also provides an in-place variant which will store the hessian vector pro ```jldoctest hvp2; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" julia> f(x) = sin(x[1] * x[2]) +f (generic function with 1 method) julia> res = Vector{Float64}(undef, 2); @@ -232,7 +233,7 @@ julia> hvp!(res, f, [2.0, 3.0], [5.0, 2.7]); julia> res 2-element Vector{Float64}: - 19.6926882637302 + 19.69268826373025 16.201003759768003 ``` @@ -249,7 +250,7 @@ julia> hvp_and_gradient!(res, grad, f, [2.0, 3.0], [5.0, 2.7]) julia> res 2-element Vector{Float64}: - 19.6926882637302 + 19.69268826373025 16.201003759768003 julia> grad diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 5bdecbcceb..433906e4c8 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1494,7 +1494,7 @@ end end end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten} res = f(x) jac = if res isa AbstractArray jacobian(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x, Val(length(jac))) @@ -1566,7 +1566,6 @@ res 16.201003759768003 ``` """ - @inline function hvp!(res::X, f::F, x::X, v::X) where {F, X} grad = make_zero(x) Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) @@ -1598,15 +1597,11 @@ hvp_and_gradient!(res, grad, f, [2.0, 3.0], [5.0, 2.7]) res grad # output -2-element Vector{Float64}: - 19.6926882637302 - 16.201003759768003 2-element Vector{Float64}: 
2.880510859951098 1.920340573300732 ``` """ - @inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) return nothing From 957f082b1a11c1eb237712a6e266c9bb18dac3d3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 9 Aug 2024 13:46:28 +0200 Subject: [PATCH 217/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 67f42588ae..ff04441c83 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.28" +version = "0.12.29" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 02c58558ce381497a248a5be7eeed5a1a0335916 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Aug 2024 12:41:07 -0700 Subject: [PATCH 218/495] Mark gc preserve as readonly (#1724) --- src/compiler/utils.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index fb2ee4714a..e4825e5226 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -120,6 +120,9 @@ function is_readonly(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end + if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + return true + end for attr in collect(function_attributes(f)) if kind(attr) == kind(EnumAttribute("readonly")) return true @@ -149,6 +152,9 @@ function is_readnone(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end + if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + return true + end for attr in collect(function_attributes(cur)) if kind(attr) == kind(EnumAttribute("readnone")) return true @@ -175,6 +181,9 @@ function 
is_writeonly(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end + if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + return true + end for attr in collect(function_attributes(cur)) if kind(attr) == kind(EnumAttribute("readnone")) return true From 08dd866888a421d3fb77b602f50445b589967156 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Aug 2024 20:11:22 -0700 Subject: [PATCH 219/495] fix: BFloat16 extension and tests (#1725) --- Project.toml | 2 +- ext/EnzymeBFloat16sExt.jl | 2 +- test/Project.toml | 1 + test/ext/bfloat16s.jl | 7 +++++++ test/runtests.jl | 4 ++++ 5 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 test/ext/bfloat16s.jl diff --git a/Project.toml b/Project.toml index ff04441c83..a02ade766c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.29" +version = "0.12.30" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/ext/EnzymeBFloat16sExt.jl b/ext/EnzymeBFloat16sExt.jl index 0fda13617e..c23797ffff 100644 --- a/ext/EnzymeBFloat16sExt.jl +++ b/ext/EnzymeBFloat16sExt.jl @@ -3,7 +3,7 @@ module EnzymeBFloat16sExt using BFloat16s using Enzyme -function Enzyme.typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) +function Enzyme.typetree_inner(::Type{BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) return TypeTree(Enzyme.API.DT_BFloat16, -1, ctx) end diff --git a/test/Project.toml b/test/Project.toml index 5c8286d1af..a3f8452712 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git 
a/test/ext/bfloat16s.jl b/test/ext/bfloat16s.jl new file mode 100644 index 0000000000..0a47f48f03 --- /dev/null +++ b/test/ext/bfloat16s.jl @@ -0,0 +1,7 @@ +using Enzyme +using Test +using BFloat16s + +@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) + +@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) diff --git a/test/runtests.jl b/test/runtests.jl index 94015cfa4f..ac62137d35 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3693,6 +3693,10 @@ end include("ext/chainrulescore.jl") end include("ext/logexpfunctions.jl") + + @testset "BFloat16s ext" begin + include("ext/bfloat16s.jl") + end end From 74ecd029b8bc7308084e8e8c35ff103bdff92229 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Aug 2024 11:12:14 -0700 Subject: [PATCH 220/495] Precompilation simplification (#1715) * Precompilation simplification * fix * f * f * fix * restrict colon to floats * f * ix * fix * Update jitrules.jl * fixup * fix * fix --------- Co-authored-by: Valentin Churavy --- src/compiler.jl | 29 +- src/internal_rules.jl | 2 +- src/rules/jitrules.jl | 465 ++++++++++++++++++--------------- src/rules/typeunstablerules.jl | 179 +++++++------ test/usermixed.jl | 4 +- 5 files changed, 359 insertions(+), 320 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index e1ca4c395a..af5b32678d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5488,6 +5488,8 @@ function no_type_setting(@nospecialize(specTypes); world=nothing) return (false, false) end +const DumpPreOpt = Ref(false) + function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) @@ -6084,6 +6086,10 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end + if DumpPreOpt[] + 
API.EnzymeDumpModuleRef(mod.ref) + end + # Run early pipeline optimize!(mod, target_machine) @@ -6676,15 +6682,13 @@ end end if !RawCall && !(CC <: PrimalErrorThunk) - if rettype <: Active + if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - throw(MethodError($CC(fptr), (fn, args...))) - end - end - elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated - if length(argtypes) + is_adjoint * width + needs_tape != length(argexprs) - return quote + @show $width + @show $(length(argtypes)), $is_adjoint, $needs_tape, $(length(argexprs)) + @show $argtypes + @show $argexprs throw(MethodError($CC(fptr), (fn, args...))) end end @@ -6879,15 +6883,8 @@ end NTuple{width, jlRT} end push!(types, j_drT) - if width == 1 || rettype <: Active - push!(ccexprs, argexprs[i]) - i+=1 - else - push!(ccexprs, quote - ($(argexprs[i:i+width-1]...),) - end) - i+=width - end + push!(ccexprs, argexprs[i]) + i+=1 end end diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 8e2976944a..9e82cdd5dc 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -823,7 +823,7 @@ end function EnzymeRules.forward(func::Const{Colon}, RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, BatchDuplicated,BatchDuplicatedNoNeed}}, - start::Annotation, step::Annotation, stop::Annotation) + start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) ret = func.val(start.val, step.val, stop.val) dstart = if start isa Const zero(eltype(ret)) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 337b54ace6..58b407ba46 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -173,11 +173,24 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = 
ntuple(i->:(res[1]), Val(Width+1)) - ModifiedBetween = ntuple(i->false, Val(N+1)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + nnothing = Vector{Nothing}(undef, Width+1) + nres = Vector{Expr}(undef, Width+1) + fill!(nnothing, nothing) + fill!(nres, :(res[1])) + ModifiedBetween = Vector{Bool}(undef, N+1) + fill!(ModifiedBetween, false) + ElTypes = Vector{Expr}(undef, N) + Types = Vector{Expr}(undef, N) + for i in 1:N + @inbounds ElTypes[i] = :(eltype(Core.Typeof(args[$i]))) + @inbounds Types[i] = :(Core.Typeof(args[$i])) + end + + retres = if Width == 1 + :(return ReturnType((res[1], res[2]))) + else + :(return ReturnType((res[1], res[2]...))) + end return quote args = ($(wrapped...),) @@ -205,22 +218,18 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + forward = thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) 
if length(res) == 0 - return ReturnType($nnothing) + return ReturnType(($(nnothing...),)) end if annotation <: Const return ReturnType(($(nres...),)) end - if $Width == 1 - return ReturnType((res[1], res[2])) - else - return ReturnType((res[1], res[2]...)) - end + $retres end end @@ -242,17 +251,21 @@ end end function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(make_zero(origRet))), Val(Width)) - nres3 = ntuple(i->:(res[3]), Val(Width)) - ElTypes = ntuple(i->:(eltype($(Symbol("type_$i")))), Val(N)) - - MakeTypes = ntuple(i->:($(Symbol("type_$i")) = Core.Typeof(args[$i])), Val(N)) - - Types = ntuple(i->Symbol("type_$i"), Val(N)) - - MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) + nres = Vector{Symbol}(undef, Width+1) + fill!(nres, :origRet) + nzeros = Vector{Expr}(undef, Width) + fill!(nzeros, :(Ref(make_zero(origRet)))) + + ElTypes = Vector{Expr}(undef, N) + MakeTypes = Vector{Expr}(undef, N) + Types = Vector{Symbol}(undef, N) + MixedTypes = Vector{Expr}(undef, N) + for i in 1:N + @inbounds ElTypes[i] = :(eltype($(Symbol("type_$i")))) + @inbounds MakeTypes[i] = :($(Symbol("type_$i")) = Core.Typeof(args[$i])) + @inbounds Types[i] = Symbol("type_$i") + @inbounds MixedTypes[i] = :($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))) + end ending = if Width == 1 quote @@ -279,7 +292,19 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end end end - + + shadowretinit = if Width == 1 + :(Ref(make_zero(origRet))) + else + :(($(nzeros...),)) + end + + shadowretret = if Width == 1 + :(return ReturnType((origRet, shadow_return, tape))) + else + :(return ReturnType((origRet, shadow_return..., tape))) + end + return quote $(active_refs...) 
args = ($(wrapped...),) @@ -319,17 +344,9 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) return ReturnType(($(nres...), tape)) elseif annotation <: Active - if $Width == 1 - shadow_return = Ref(make_zero(origRet)) - else - shadow_return = ($(nzeros...),) - end + shadow_return = $shadowretinit tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, shadow_return, tape)) - else - return ReturnType((origRet, shadow_return..., tape)) - end + $shadowretret end $ending @@ -409,13 +426,14 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act shadowret = :(($(shadowret...),)) end - ElTypes = ntuple(i->:(eltype($(Symbol("type_$i")))), Val(N)) - - MakeTypes = ntuple(i->:($(Symbol("type_$i")) = Core.Typeof(args[$i])), Val(N)) - - Types = ntuple(i->Symbol("type_$i"), Val(N)) - - MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) + ElTypes = Vector{Expr}(undef, N) + MakeTypes = Vector{Expr}(undef, N) + Types = Vector{Symbol}(undef, N) + for i in 1:N + @inbounds ElTypes[i] = :(eltype($(Symbol("type_$i")))) + @inbounds MakeTypes[i] = :($(Symbol("type_$i")) = Core.Typeof(args[$i])) + @inbounds Types[i] = Symbol("type_$i") + end quote $(active_refs...) @@ -446,14 +464,8 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) - tup = if annotation0 <: Active + tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated adjoint(dupClosure0 ? 
Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] - elseif annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated - if $Width == 1 - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] - else - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret..., tape.internal_tape)[1] - end else adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] end @@ -715,7 +727,10 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType end function body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) - wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + wrappedexexpand = Vector{Expr}(undef, N) + for i in 1:N + @inbounds wrappedexexpand[i] = :($(wrapped[i])...) + end return quote $(active_refs...) args = ($(wrappedexexpand...),) @@ -742,50 +757,67 @@ end return body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) end -function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs - ntuple(Val(Nargs)) do i +@generated function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs + expr = Vector{Expr}(undef, Nargs) + for i in 1:Nargs + @inbounds expr[i] = :(args[$i].val) + end + return quote Base.@_inline_meta - args[i].val + ($(expr...),) end end -function shadow_tuple(::Type{Ann}, ::Val{1}, args::Vararg{Annotation, Nargs}) where {Ann, Nargs} - res = ntuple(Val(Nargs)) do i - Base.@_inline_meta - @assert !(args[i] isa Active) - if args[i] isa Const - args[i].val - elseif args[i] isa MixedDuplicated - args[i].dval[] - else - args[i].dval +@generated function shadow_tuple(::Type{Ann}, ::Val{1}, args::Vararg{Annotation, Nargs}) where {Ann, Nargs} + expr = Vector{Expr}(undef, Nargs) + for i in 1:Nargs + @inbounds expr[i] = quote + @assert !(args[$i] isa Active) + if args[$i] isa Const + args[$i].val + elseif args[$i] isa MixedDuplicated + args[$i].dval[] + else + args[$i].dval + end end end 
+ rval = :(($(expr...),)) if Ann <: MixedDuplicated - Ref(res) - else - res + rval = :(Ref($rval)) end -end - -function shadow_tuple(::Type{Ann}, ::Val{width}, args::Vararg{Annotation, Nargs}) where {Ann, width, Nargs} - ntuple(Val(width)) do w - res = ntuple(Val(Nargs)) do i - Base.@_inline_meta - @assert !(args[i] isa Active) - if args[i] isa Const - args[i].val - elseif args[i] isa BatchMixedDuplicated - args[i].dval[w][] - else - args[i].dval[w] + return quote + Base.@_inline_meta + $rval + end +end + +@generated function shadow_tuple(::Type{Ann}, ::Val{width}, args::Vararg{Annotation, Nargs}) where {Ann, width, Nargs} + wexpr = Vector{Expr}(undef, width) + for w in 1:width + expr = Vector{Expr}(undef, Nargs) + for i in 1:Nargs + @inbounds expr[i] = quote + @assert !(args[$i] isa Active) + if args[$i] isa Const + args[$i].val + elseif args[$i] isa BatchMixedDuplicated + args[$i].dval[$w][] + else + args[$i].dval[$w] + end end end + rval = :(($(expr...),)) if Ann <: BatchMixedDuplicated - Ref(res) - else - res + rval = :(Ref($rval)) end + @inbounds wexpr[w] = rval + end + + return quote + Base.@_inline_meta + ($(wexpr...),) end end @@ -887,10 +919,13 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} end function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) - wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) - results = Expr[] + wrappedexexpand = Vector{Expr}(undef, N) + for i in 1:N + @inbounds wrappedexexpand[i] = :($(wrapped[i])...) 
+ end + results = Vector{Expr}(undef, Width+1) for i in 1:(Width+1) - push!(results, :(tmpvals[$i])) + results[i] = :(tmpvals[$i]) end return quote refs = Base.RefValue[] @@ -935,148 +970,156 @@ function add_into_vec!(val::T, expr, vec, idx_in_vec) where T end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, tt′, DF, Nargs} - ReturnPrimal = Val(true) - ModifiedBetween = Val(ModifiedBetween0) - - dupClosure = dupClosure0 && !guaranteed_const(FT) - FA = dupClosure ? Duplicated{FT} : Const{FT} - - tt = Enzyme.vaEltypes(tt′) +@generated function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{ttp}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, ttp, DF, Nargs} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - - annotation = if width != 1 - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, width} - elseif annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt, width} - elseif annotation0 <: Active - Active{rt} - else - Const{rt} - end - else - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - Duplicated{rt} - elseif annotation0 <: MixedDuplicated - MixedDuplicated{rt} - elseif annotation0 <: Active - Active{rt} + nontupexprs = Vector{Expr}(undef, Nargs) + for i in 1:Nargs + mid = if width == 1 + :(tape.shadow_return[][$i]) else - Const{rt} + mexprs = Vector{Expr}(undef, width) + for w in 1:width + @inbounds mexprs[w] = :(tape.shadow_return[$w][][$i]) + end + quote + ($(mexprs...),) + end end - end - - tup 
= if f != Base.tuple - world = codegen_world_age(FT, tt) - fa = if dupClosure - if width == 1 - Duplicated(f, df) + @inbounds nontupexprs[i] = quote + if args[$i] isa Active || args[$i] isa MixedDuplicated || args[$i] isa BatchMixedDuplicated + $mid else - BatchDuplicated(f, df) + nothing end - else - Const(f) end - opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, FA, - annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) - - args2 = if tape.shadow_return !== nothing - if width == 1 - (args..., tape.shadow_return[]) - else - shads = ntuple(Val(width)) do w - Base.@_inline_meta - tape.shadow_return[w][] - end - if annotation <: MixedDuplicated || annotation <: BatchMixedDuplicated - (args..., shads...,) - else - (args..., shads) + end + + endexprs = Matrix{Expr}(undef, Nargs, width) + for i in 1:Nargs + for w in 1:width + @inbounds endexprs[i, w] = quote + if args[$i] isa Active || args[$i] isa MixedDuplicated || args[$i] isa BatchMixedDuplicated + expr = if args[$i] isa Active || f == Base.tuple + if $width == 1 + tup[$i] + else + tup[$i][$w] + end + elseif args[$i] isa MixedDuplicated + args[$i].dval[] + else + # if args[$i] isa BatchMixedDuplicated + args[$i].dval[$w][] + end + + idx_of_vec, idx_in_vec = $(lengths[i]) + vec = @inbounds shadowargs[idx_of_vec][$w] + if vec isa Base.RefValue + vecld = vec[] + T = Core.Typeof(vecld) + vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), expr) + else + val = @inbounds vec[idx_in_vec] + add_into_vec!(Base.inferencebarrier(val), expr, vec, idx_in_vec) + end end end - else - args end + end - adjoint(fa, args2..., tape.internal_tape)[1] + tgen = if FT == typeof(Base.tuple) + :(tup = ($(nontupexprs...),)) else - ntuple(Val(Nargs)) do i - Base.@_inline_meta - if args[i] isa Active - if width == 1 - tape.shadow_return[][i] + annotation = if width != 1 + quote + if annotation0 <: DuplicatedNoNeed || 
annotation0 <: Duplicated + BatchDuplicated{rt, $width} + elseif annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt, $width} + elseif annotation0 <: Active + Active{rt} else - ntuple(Val(width)) do w - Base.@_inline_meta - tape.shadow_return[w][][i] - end + Const{rt} end - elseif args[i] isa MixedDuplicated || args[i] isa BatchMixedDuplicated - if width == 1 - tape.shadow_return[][i] + end + else + quote + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + elseif annotation0 <: MixedDuplicated + MixedDuplicated{rt} + elseif annotation0 <: Active + Active{rt} else - ntuple(Val(width)) do w - Base.@_inline_meta - tape.shadow_return[w][][i] - end + Const{rt} end - else - nothing end end - end - ntuple(Val(Nargs)) do i - Base.@_inline_meta + shadadj = if width == 1 + :(adjoint(fa, args..., tape.shadow_return[], tape.internal_tape)[1]) + else + margs = Vector{Expr}(undef, width) + for w in 1:width + @inbounds margs[w] = :(tape.shadow_return[$w][]) + end + :(adjoint(fa, args..., ($(margs...),), tape.internal_tape)[1]) + end - ntuple(Val(width)) do w - Base.@_inline_meta - if args[i] isa Active || args[i] isa MixedDuplicated || args[i] isa BatchMixedDuplicated - expr = if args[i] isa Active || f == Base.tuple - if width == 1 - tup[i] - else - tup[i][w] - end - elseif args[i] isa MixedDuplicated - args[i].dval[] - else - # if args[i] isa BatchMixedDuplicated - args[i].dval[w][] - end + tt = Enzyme.vaEltypes(ttp) - idx_of_vec, idx_in_vec = lengths[i] - vec = @inbounds shadowargs[idx_of_vec][w] - if vec isa Base.RefValue - vecld = vec[] - T = Core.Typeof(vecld) - vec[] = splatnew(T, ntuple(Val(fieldcount(T))) do i - Base.@_inline_meta - prev = getfield(vecld, i) - if i == idx_in_vec - recursive_add(prev, expr, identity, guaranteed_nonactive) - else - prev - end - end) - else - val = @inbounds vec[idx_in_vec] - add_into_vec!(Base.inferencebarrier(val), expr, vec, idx_in_vec) - end - end + quote + ReturnPrimal = Val(true) + ModifiedBetween 
= Val($ModifiedBetween0) + + dupClosure = $dupClosure0 && !guaranteed_const($FT) + FA = dupClosure ? Duplicated{$FT} : Const{$FT} - nothing + tt = $tt + + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = $annotation + world = codegen_world_age(FT, tt) + + fa = if dupClosure + $(width == 1 ? :Duplicated : :BatchDuplicated)(f, df) + else + Const(f) + end + opt_mi = Val(world) + forward, adjoint = thunk(opt_mi, FA, + annotation, $ttp, Val(API.DEM_ReverseModePrimal), Val($width), + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + + tup = if tape.shadow_return !== nothing + $shadadj + else + adjoint(fa, args..., tape.internal_tape)[1] + end end + end + return quote + $tgen + $(endexprs...) nothing end - nothing +end + +@generated function ntuple_pair(::Val{Len}, ::Val{i}) where {Len, i} + mexprs = Vector{Expr}(undef, Len) + for j in 1:Len + @inbounds mexprs[j] = quote + ($i, $j) + end + end + quote + Base.@_inline_meta + ($(mexprs...),) + end end function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs, active_refs) @@ -1084,23 +1127,23 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado if Width == 1 shadowret = :(tape.shadow_return[]) else - shadowret = [] + shadowret = Expr[] for w in 1:Width push!(shadowret, :(tape.shadow_return[$w][])) end shadowret = :(($(shadowret...),)) end - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) - - wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) - lengths = ntuple(i->quote - (ntuple(Val(length($(primargs[i])))) do j - Base.@_inline_meta - ($i, j) - end) - end, Val(N)) + wrappedexexpand = Vector{Expr}(undef, N) + for i in 1:N + wrappedexexpand[i] = :($(wrapped[i])...) 
+ end + lengths = Vector{Expr}(undef, N) + for i in 1:N + lengths[i] = quote + ntuple_pair(Val(length($(primargs[i]))), Val($i)) + end + end shadowsplat = Expr[] for s in shadowargs diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index dafc367ef3..42fcbe18cf 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -598,6 +598,47 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end +@generated function ntuple_ref_zero(::Val{N}, ::Type{RT}, res) where {N, RT} + expr = Vector{Expr}(undef, N) + fill!(expr, :(Ref{$RT}(make_zero(res)))) + return quote + Base.@_inline_meta + ($(expr...),) + end +end + +@generated function ntuple_ref_lookup(::Val{N}, ::Type{RT}, dptrs, symname) where {N, RT} + expr = Vector{Expr}(undef, N) + for i in 1:N + @inbounds expr[i] = quote + begin + dv = dptrs[$i] + Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname)) + end + end + end + return quote + Base.@_inline_meta + ($(expr...),) + end +end + +@generated function ntuple_lookup(::Val{N}, ptrs, symname) where {N} + expr = Vector{Expr}(undef, N) + for i in 1:N + @inbounds expr[i] = quote + begin + dv = ptrs[$i] + getfield(dv isa Base.RefValue ? 
dv[] : dv, symname) + end + end + end + return quote + Base.@_inline_meta + ($(expr...),) + end +end + function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname) @@ -611,41 +652,27 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else - return NT(ntuple(Val(1+length(dptrs))) do i - Base.@_inline_meta - Ref{RT}(make_zero(res)) - end) + return NT(ntuple_ref_zero(Val(1+length(dptrs)), RT, res)) end elseif actreg == MixedState if length(dptrs) == 0 return Ref{RT}(res) else - fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i - Base.@_inline_meta - dv = dptrs[i] - Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname)) - end)...)) + fval = NT((Ref{RT}(res), ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname)...)) return fval end elseif isconst if length(dptrs) == 0 return make_zero(res) else - fval = NT((res, (ntuple(Val(length(dptrs))) do i - Base.@_inline_meta - make_zero(res) - end)...)) + fval = NT((res, ntuple_ref_zero(Val(length(dptrs)), RT, res)...)) return fval end else if length(dptrs) == 0 return res else - fval = NT((res, (ntuple(Val(length(dptrs))) do i - Base.@_inline_meta - dv = dptrs[i] - getfield(dv isa Base.RefValue ? 
dv[] : dv, symname) - end)...)) + fval = NT((res, ntuple_lookup(Val(length(dptrs)), dptrs, symname)...)) return fval end end @@ -663,46 +690,49 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc if length(dptrs) == 0 return Ref{RT}(make_zero(res))::Any else - return NT(ntuple(Val(1+length(dptrs))) do i - Base.@_inline_meta - Ref{RT}(make_zero(res)) - end) + return NT(ntuple_ref_zero(Val(1+length(dptrs)), RT, res)) end elseif actreg == MixedState if length(dptrs) == 0 return Ref{RT}(res) else - fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i - Base.@_inline_meta - dv = dptrs[i] - Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname+1)) - end)...)) + fval = NT((Ref{RT}(res), ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname+1)...)) return fval end elseif isconst if length(dptrs) == 0 return make_zero(res)::Any else - fval = NT((res, (ntuple(Val(length(dptrs))) do i - Base.@_inline_meta - make_zero(res) - end)...)) + fval = NT((res, ntuple_ref_zero(Val(length(dptrs)), RT, res)...)) return fval end else if length(dptrs) == 0 return res::Any else - fval = NT((res, (ntuple(Val(length(dptrs))) do i - Base.@_inline_meta - dv = dptrs[i] - getfield(dv isa Base.RefValue ? 
dv[] : dv, symname+1) - end)...)) + fval = NT((res, ntuple_lookup(Val(length(dptrs)), dptrs, symname+1)...)) return fval end end end +@generated function recursive_field_add(::Type{dRT}, vload, ::Val{symname}, dret) where {dRT, symname} + N = fieldcount(dRT) + exprs = Vector{Expr}(undef, N) + for i in 1:N + @inbounds exprs[i] = if fieldname(dRT, i) == symname + :(recursive_add(getfield(vload, $i), dret, identity, guaranteed_nonactive)) + else + :(getfield(vload, $i)) + end + end + res = Expr(:splatnew, dRT, :(($(exprs...)),)) + return quote + Base.@_inline_meta + $res + end +end + function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue getfield(dptr[], symname) @@ -718,15 +748,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i - Base.@_inline_meta - prev = getfield(vload, i) - if fieldname(dRT, i) == symname - recursive_add(prev, dret[], identity, guaranteed_nonactive) - else - prev - end - end) + dptr[] = recursive_field_add(dRT, vload, Val(symname), dret[]) else setfield!(dptr, symname, recursive_add(cur, dret[], identity, guaranteed_nonactive)) end @@ -734,15 +756,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j - Base.@_inline_meta - prev = getfield(vload, j) - if fieldname(dRT, j) == symname - recursive_add(prev, dret[1][], identity, guaranteed_nonactive) - else - prev - end - end) + dptr[] = recursive_field_add(dRT, vload, Val(symname), dret[1][]) else setfield!(dptr, symname, recursive_add(cur, dret[1][])) end @@ -750,15 +764,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, if dptrs[i] 
isa Base.RefValue vload = dptrs[i][] dRT = Core.Typeof(vload) - dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j - Base.@_inline_meta - prev = getfield(vload, j) - if fieldname(dRT, j) == symname - recursive_add(prev, dret[1+i][], identity, guaranteed_nonactive) - else - prev - end - end) + dptrs[i][] = recursive_field_add(dRT, vload, Val(symname), dret[1+i][]) else curi = if dptr isa Base.RefValue Base.getfield(dptrs[i][], symname) @@ -773,6 +779,23 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, return nothing end +@generated function recursive_index_add(::Type{dRT}, vload, ::Val{symname}, dret) where {dRT, symname} + N = fieldcount(dRT) + exprs = Vector{Expr}(undef, N) + for i in 1:N + @inbounds exprs[i] = if i == symname + :(recursive_add(getfield(vload, $i), dret, identity, guaranteed_nonactive)) + else + :(getfield(vload, $i)) + end + end + res = Expr(:splatnew, dRT, :(($(exprs...)),)) + return quote + Base.@_inline_meta + $res + end +end + function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) @@ -788,15 +811,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i - Base.@_inline_meta - prev = getfield(vload, i) - if i == symname+1 - recursive_add(prev, dret[], identity, guaranteed_nonactive) - else - prev - end - end) + dptr[] = recursive_index_add(dRT, vload, Val(symname+1), dret[]) else setfield!(dptr, symname+1, recursive_add(cur, dret[], identity, guaranteed_nonactive)) end @@ -804,15 +819,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j - 
Base.@_inline_meta - prev = getfield(vload, j) - if j == symname+1 - recursive_add(prev, dret[1][], identity, guaranteed_nonactive) - else - prev - end - end) + dptr[] = recursive_index_add(dRT, vload, Val(symname+1), dret[1][]) else setfield!(dptr, symname+1, recursive_add(cur, dret[1][], identity, guaranteed_nonactive)) end @@ -820,15 +827,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} if dptrs[i] isa Base.RefValue vload = dptrs[i][] dRT = Core.Typeof(vload) - dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j - Base.@_inline_meta - prev = getfield(vload, j) - if j == symname+1 - recursive_add(prev, dret[1+i][], identity, guaranteed_nonactive) - else - prev - end - end) + dptrs[i][] = recursive_index_add(dRT, vload, Val(symname+1), dret[1+i][]) else curi = if dptr isa Base.RefValue Base.getfield(dptrs[i][], symname+1) diff --git a/test/usermixed.jl b/test/usermixed.jl index f97c5737ec..e48dba1b72 100644 --- a/test/usermixed.jl +++ b/test/usermixed.jl @@ -45,7 +45,7 @@ end @test dres[2][][1] ≈ 0.0 @test dres[2][][2] === dy2 - outs = rev(Const(user_mixret), Active(x), BatchDuplicated(y, (dy, dy2)), (47.0, dy), (56.0, dy), tape) + outs = rev(Const(user_mixret), Active(x), BatchDuplicated(y, (dy, dy2)), ((47.0, dy), (56.0, dy)), tape) @test outs[1][1][1] ≈ 47.0 @test outs[1][1][2] ≈ 56.0 @@ -86,7 +86,7 @@ end @test dres[1][] ≈ 0.0 @test dres[2][] ≈ 0.0 - outs = rev(Const(user_fltret), Active(x), BatchDuplicated(y, (dy, dy2)), 47.0, 56.0, tape) + outs = rev(Const(user_fltret), Active(x), BatchDuplicated(y, (dy, dy2)), (47.0, 56.0), tape) @test outs[1][1][1] ≈ 47.0 @test outs[1][1][2] ≈ 56.0 From 15a19d2e4c8fa399bc001be91d4d644633c703c3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Aug 2024 14:06:57 -0700 Subject: [PATCH 221/495] boxfloat fixup (#1726) --- src/rules/llvmrules.jl | 61 +++++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git 
a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 664d643af0..7b27cf298e 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -589,7 +589,12 @@ end @register_fwd function boxfloat_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) - if is_constant_value(gutils, orig) + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true end @@ -616,7 +621,12 @@ end @register_aug function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) origops = collect(operands(orig)) width = get_width(gutils) - if is_constant_value(gutils, orig) + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true end @@ -642,30 +652,37 @@ end end @register_rev function boxfloat_rev(B, orig, gutils, tape) + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 + return nothing + end + origops = collect(operands(orig)) width = get_width(gutils) - if !is_constant_value(gutils, orig) - ip = lookup_value(gutils, invert_pointer(gutils, orig, B), B) - flt = value_type(origops[1]) - if width == 1 - ipc = bitcast!(B, ip, LLVM.PointerType(flt, addrspace(value_type(orig)))) + ip = lookup_value(gutils, invert_pointer(gutils, orig, B), B) + flt = value_type(origops[1]) + if width == 1 + ipc = bitcast!(B, ip, LLVM.PointerType(flt, addrspace(value_type(orig)))) + ld = load!(B, flt, ipc) + store!(B, ConstantFP(flt, 0.0), ipc) + if !is_constant_value(gutils, origops[1]) + 
API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], ld, B, flt) + end + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) + for idx in 1:width + ipc = extract_value!(B, ip, idx-1) + ipc = bitcast!(B, ipc, LLVM.PointerType(flt, addrspace(value_type(orig)))) ld = load!(B, flt, ipc) store!(B, ConstantFP(flt, 0.0), ipc) - if !is_constant_value(gutils, origops[1]) - API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], ld, B, flt) - end - else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) - for idx in 1:width - ipc = extract_value!(B, ip, idx-1) - ipc = bitcast!(B, ipc, LLVM.PointerType(flt, addrspace(value_type(orig)))) - ld = load!(B, flt, ipc) - store!(B, ConstantFP(flt, 0.0), ipc) - shadowres = insert_value!(B, shadowres, ld, idx-1) - end - if !is_constant_value(gutils, origops[1]) - API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], shadowret, B, flt) - end + shadowres = insert_value!(B, shadowres, ld, idx-1) + end + if !is_constant_value(gutils, origops[1]) + API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], shadowret, B, flt) end end return nothing From cf619b33d7ba2524ed22f1eae715f2be0792bced Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Aug 2024 15:36:35 -0700 Subject: [PATCH 222/495] Unhash the jl_idtable_rehash call (#1727) --- src/compiler/validation.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 8bf562addf..9c8bc3921c 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -100,6 +100,7 @@ module FFI "jl_array_isassigned", "ijl_array_isassigned", "jl_array_ptr_copy", "ijl_array_ptr_copy", "jl_array_typetagdata", "ijl_array_typetagdata", + "jl_idtable_rehash" ) for name in known_names sym = LLVM.find_symbol(name) From 0cf47c18f17ad9b47cb848252a965a0a255bba04 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Aug 2024 21:53:02 -0700 Subject: [PATCH 223/495] More bfloat fix (#1728) 
* More bfloat fix * Update bfloat16s.jl * Update bfloat16s.jl --- ext/EnzymeBFloat16sExt.jl | 2 +- src/typetree.jl | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ext/EnzymeBFloat16sExt.jl b/ext/EnzymeBFloat16sExt.jl index c23797ffff..050825aef0 100644 --- a/ext/EnzymeBFloat16sExt.jl +++ b/ext/EnzymeBFloat16sExt.jl @@ -4,7 +4,7 @@ using BFloat16s using Enzyme function Enzyme.typetree_inner(::Type{BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) - return TypeTree(Enzyme.API.DT_BFloat16, -1, ctx) + return Enzyme.TypeTree(Enzyme.API.DT_BFloat16, -1, ctx) end end diff --git a/src/typetree.jl b/src/typetree.jl index 065dccbbd8..ae4329c172 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -121,11 +121,6 @@ function typetree_inner(::Type{BigFloat}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat} - GPUCompiler.@safe_warn "Unknown floating point type" T - return TypeTree() -end - function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable) return TypeTree() end @@ -225,6 +220,10 @@ function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) end end + if T <: AbstractFloat + throw(AssertionError("Unknown floating point type $T")) + end + try fieldcount(T) catch From 410a8662c5920e11e5eee22db3939e616438b512 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Aug 2024 09:49:11 -0400 Subject: [PATCH 224/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a02ade766c..e9a224b325 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.142" +Enzyme_jll = "0.0.143" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From 65770c98bee7ef7da57bfd135e6f8f91cee9823c Mon Sep 17 
00:00:00 2001 From: William Moses Date: Mon, 12 Aug 2024 07:45:10 -0700 Subject: [PATCH 225/495] Update compiler.jl (#1729) --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index af5b32678d..b04550fe0a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5538,8 +5538,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; disableFallback = String[] - ForwardModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "syrk", "potrf") - ReverseModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "trmv", "syrk", "trmm", "trsm", "potrf") + ForwardModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "symm", "syrk", "potrf") + ReverseModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "symm", "trmv", "syrk", "trmm", "trsm", "potrf") ForwardModeTypes = ("s", "d", "c", "z") ReverseModeTypes = ("s", "d") # Tablegen BLAS does not support forward mode yet From 95fd3f69587b0f790243b4e362a493ed59890e2a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Aug 2024 18:35:42 -0700 Subject: [PATCH 226/495] chore: bump jll to 0.0.144 (#1731) --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index e9a224b325..c3697f8888 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.30" +version = "0.12.31" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.143" +Enzyme_jll = "0.0.144" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From e1edaf7482dee3be5bcf35afb7527cea808a1f07 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 17 Aug 2024 12:52:01 -0500 
Subject: [PATCH 227/495] Fix cuda tape type (#1737) * Fix cuda tape type * bump --- Project.toml | 2 +- src/Enzyme.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index c3697f8888..3f83fafc0a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.31" +version = "0.12.32" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 433906e4c8..b7d86b4705 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -783,7 +783,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType params = Compiler.EnzymeCompilerParams( Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, - ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI + ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI, #=errifwritte=#false ) job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false)) From ffc1035294db29f80222cb10b804faedd8a97069 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 20 Aug 2024 16:47:04 -0500 Subject: [PATCH 228/495] Fix return diffe (#1740) * Fix return diffe * Add test --- src/compiler/optimize.jl | 3 +++ src/internal_rules.jl | 6 ----- src/rules/jitrules.jl | 8 +++---- src/rules/llvmrules.jl | 2 +- test/runtests.jl | 48 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 56 insertions(+), 11 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index a5a4908def..d27c7800c5 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -701,6 +701,9 @@ function nodecayed_phis!(mod::LLVM.Module) if addr == 11 && isa(v, LLVM.ConstantExpr) if opcode(v) == LLVM.API.LLVMAddrSpaceCast v2 = operands(v)[1] + if addrspace(value_type(v2)) == 10 + return v2, offset, hasload + end if 
addrspace(value_type(v2)) == 0 if addr == 11 v2 = const_addrspacecast(v2, LLVM.PointerType(eltype(value_type(v)), 10)) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 9e82cdd5dc..c2a31c2e4d 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -66,15 +66,9 @@ end function EnzymeRules.inactive(::typeof(Core.kwfunc), args...) return nothing end -function EnzymeRules.inactive(::typeof(Random.rand), ::Random.AbstractRNG, ::Random.Sampler) - return nothing -end function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, ::Random.Sampler, ::AbstractArray) return nothing end -function EnzymeRules.inactive(::typeof(Random.randn), args...) - return nothing -end function EnzymeRules.inactive(::typeof(Random.randn!), args...) return nothing end diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 58b407ba46..bdcdd79b25 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1445,7 +1445,7 @@ end function common_generic_rev(offset, B, orig, gutils, tape)::Cvoid needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return nothing @@ -1562,7 +1562,7 @@ end function common_apply_latest_rev(offset, B, orig, gutils, tape)::Cvoid needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return nothing @@ -1759,7 +1759,7 @@ end function 
common_apply_iterate_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return nothing @@ -1882,7 +1882,7 @@ end function common_invoke_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return nothing diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 7b27cf298e..3d578f4dfc 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -655,7 +655,7 @@ end needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) if is_constant_value(gutils, orig) || needsShadowP[] == 0 return nothing diff --git a/test/runtests.jl b/test/runtests.jl index ac62137d35..114dfb6833 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3268,6 +3268,54 @@ end end end +@inline function uns_mymean(f, A, ::Type{T}, c) where T + c && return Base.inferencebarrier(nothing) + x1 = f(@inbounds A[1]) / 1 + return @inbounds A[1][1] +end + +function uns_sum2(x::Array{T})::T where T + op = Base.add_sum + itr = x + y = iterate(itr)::Tuple{T, Int} + v = y[1]::T + while true + y = iterate(itr, y[2]) + y === 
nothing && break + v = (v + y[1])::T + end + return v +end + +function uns_ad_forward(scale_diag::Vector{T}, c) where T + ccall(:jl_, Cvoid, (Any,), scale_diag) + res = uns_mymean(uns_sum2, [scale_diag,], T, c) + return res +end + +@testset "Split box float32" begin + q = ones(Float32, 1) + dx = make_zero(q) + res, y = Enzyme.autodiff( + Enzyme.ReverseWithPrimal, + uns_ad_forward, + Enzyme.Active, + Enzyme.Duplicated(q, dx), + Enzyme.Const(false), + ) + @test dx ≈ Float32[1.0] + q = ones(Float64, 1) + dx = make_zero(q) + res, y = Enzyme.autodiff( + Enzyme.ReverseWithPrimal, + uns_ad_forward, + Enzyme.Active, + Enzyme.Duplicated(q, dx), + Enzyme.Const(false), + ) + @test dx ≈ Float64[1.0] +end + @static if VERSION < v"1.8-" || VERSION >= v"1.9-" @inline extract_bc(bc, ::Val{:north}) = (bc.north) @inline extract_bc(bc, ::Val{:top}) = (bc.top) From c75bbd4610b73c57e4acebaa852974abca44a538 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Tue, 20 Aug 2024 21:28:43 -0400 Subject: [PATCH 229/495] small fix for static array onehot (#1732) * small fix for static array onehot * weaken tests for old julia versions --- ext/EnzymeStaticArraysExt.jl | 2 +- test/runtests.jl | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 672d1c03bc..6dbd390cb7 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -14,7 +14,7 @@ end ntuple(Val(endl-start+1)) do i Base.@_inline_meta StaticArrays.SArray{S, T, N, L}( - ntuple(Val(N)) do idx + ntuple(Val(L)) do idx Base.@_inline_meta return (i + start - 1 == idx) ? 
1.0 : 0.0 end) diff --git a/test/runtests.jl b/test/runtests.jl index 114dfb6833..113eb6f531 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2774,6 +2774,22 @@ end @test dx isa SArray @test dx ≈ [0 30 0] + x = @SVector [1.0, 2.0, 3.0] + y = onehot(x) + # this should be a very specific type of SArray, but there + # is a bizarre issue with older julia versions where it can be MArray + @test eltype(y) <: StaticVector + @test length(y) == 3 + @test y[1] == [1.0, 0.0, 0.0] + @test y[2] == [0.0, 1.0, 0.0] + @test y[3] == [0.0, 0.0, 1.0] + + y = onehot(x, 2, 3) + @test eltype(y) <: StaticVector + @test length(y) == 2 + @test y[1] == [0.0, 1.0, 0.0] + @test y[2] == [0.0, 0.0, 1.0] + @static if VERSION ≥ v"1.9-" x = @SArray [5.0 0.0 6.0] dx = Enzyme.gradient(Forward, prod, x) From e7be3ce315f7895bee6fe65ad84e02ce24833bb8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 21 Aug 2024 08:29:58 -0500 Subject: [PATCH 230/495] Update Project.toml (#1742) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3f83fafc0a..981e705c23 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.144" +Enzyme_jll = "0.0.145" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" From 8bf8b4897837b57f927e48130938f6bfe6da554b Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 22 Aug 2024 11:54:24 -0500 Subject: [PATCH 231/495] Fix cpu features (#1744) --- Project.toml | 4 ++-- src/compiler/optimize.jl | 33 ++++++++++++++++++++++++++------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 981e705c23..50f357574b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.32" +version 
= "0.12.33" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -36,7 +36,7 @@ CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" Enzyme_jll = "0.0.145" -GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" +GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" ObjectFile = "0.4" diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index d27c7800c5..8c6385edb8 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -352,6 +352,31 @@ else end end +@static if VERSION < v"1.11-" + function cpu_features_tm!(pm, tm) + @static if isdefined(LLVM.Interop, :cpu_features!) + LLVM.Interop.cpu_features!(pm) + else + @static if isdefined(GPUCompiler, :cpu_features!) + GPUCompiler.cpu_features!(pm) + end + end + end +else + function cpu_features_tm!(pm, tm) + function cpu_features(mod) + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMModulePassManager()) do mpm + add!(mpm, CPUFeaturesPass()) + end + run!(pb, mod) + end + return true + end + add!(pm, ModulePass("CPUFeatures", cpu_features)) + end +end + function addNA(inst, node::LLVM.Metadata, MD) md = metadata(inst) next = nothing @@ -2041,13 +2066,7 @@ function optimize!(mod::LLVM.Module, tm) basic_alias_analysis!(pm) cfgsimplification!(pm) dce!(pm) -@static if isdefined(LLVM.Interop, :cpu_features!) - LLVM.Interop.cpu_features!(pm) -else -@static if isdefined(GPUCompiler, :cpu_features!) - GPUCompiler.cpu_features!(pm) -end -end + cpu_features_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? 
mem_cpy_opt!(pm) always_inliner!(pm) From 71ae71353a7dd5f4698036ef4d55a52bb69474ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 22 Aug 2024 15:40:06 -0500 Subject: [PATCH 232/495] Try diferent name for gpu exception (#1746) --- src/compiler.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index b04550fe0a..2fe1e38d69 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6079,10 +6079,12 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # annotate annotate!(mod, mode) - if haskey(functions(mod), "gpu_report_exception") - exc = functions(mod)["gpu_report_exception"] - if !isempty(blocks(exc)) - linkage!(exc, LLVM.API.LLVMExternalLinkage) + for name in ("gpu_report_exception", "report_exception") + if haskey(functions(mod), name) + exc = functions(mod)[name] + if !isempty(blocks(exc)) + linkage!(exc, LLVM.API.LLVMExternalLinkage) + end end end @@ -6097,10 +6099,12 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; GPUCompiler.optimize_module!(parent_job, mod) end - if haskey(functions(mod), "gpu_report_exception") - exc = functions(mod)["gpu_report_exception"] - if !isempty(blocks(exc)) - linkage!(exc, LLVM.API.LLVMInternalLinkage) + for name in ("gpu_report_exception", "report_exception") + if haskey(functions(mod), name) + exc = functions(mod)[name] + if !isempty(blocks(exc)) + linkage!(exc, LLVM.API.LLVMInternalLinkage) + end end end From b9be0ac09923d905f45d13383af8ce19d9cd2d58 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 25 Aug 2024 16:30:14 -0500 Subject: [PATCH 233/495] Handle const addr casts (#1752) * Handle const addr casts * Update runtests.jl * Update runtests.jl --- src/absint.jl | 28 +++++++++++++++++----------- test/runtests.jl | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index a9748d2dd5..b84657aadb 100644 
--- a/src/absint.jl +++ b/src/absint.jl @@ -6,6 +6,11 @@ function absint(arg::LLVM.Value, partial::Bool=false) isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end + if isa(arg, ConstantExpr) + if opcode(arg) == LLVM.API.LLVMAddrSpaceCast || opcode(arg) == LLVM.API.LLVMBitCast + return absint(operands(arg)[1], partial) + end + end if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" @@ -92,19 +97,14 @@ function absint(arg::LLVM.Value, partial::Bool=false) end if isa(arg, ConstantExpr) ce = arg - while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr - ce = operands(ce)[1] - else - break + if opcode(ce) == LLVM.API.LLVMIntToPtr + ce = operands(ce)[1] + if isa(ce, LLVM.ConstantInt) + ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce)) + typ = Base.unsafe_pointer_to_objref(ptr) + return (true, typ) end end - if !isa(ce, LLVM.ConstantInt) - return (false, nothing) - end - ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce)) - typ = Base.unsafe_pointer_to_objref(ptr) - return (true, typ) end if isa(arg, GlobalVariable) @@ -154,6 +154,12 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ isa(arg, LLVM.AddrSpaceCastInst) return abs_typeof(operands(arg)[1], partial) end + if isa(arg, ConstantExpr) + if opcode(arg) == LLVM.API.LLVMAddrSpaceCast || opcode(arg) == LLVM.API.LLVMBitCast + return abs_typeof(operands(arg)[1], partial) + end + end + if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" diff --git a/test/runtests.jl b/test/runtests.jl index 113eb6f531..477f646c74 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -780,6 +780,38 @@ end @test hess[1][2] ≈ 1.0 @test hess[2][1] ≈ 1.0 @test hess[2][2] ≈ 0.0 + + function f_ip(x, tmp) + tmp .= x ./ 2 + return dot(tmp, x) + end + + function f_gradient_deferred!(dx, x, tmp) + dtmp = make_zero(tmp) + autodiff_deferred(Reverse, f_ip, Active, 
Duplicated(x, dx), Duplicated(tmp, dtmp)) + return nothing + end + + function f_hvp!(hv, x, v, tmp) + dx = make_zero(x) + btmp = make_zero(tmp) + autodiff( + Forward, + f_gradient_deferred!, + Duplicated(dx, hv), + Duplicated(x, v), + Duplicated(tmp, btmp), + ) + return nothing + end + + x = [1.0] + v = [-1.0] + hv = make_zero(v) + tmp = similar(x) + + f_hvp!(hv, x, v, tmp) + @test hv ≈ [-1.0] end @testset "Array tests" begin From 0f7dfcc7f870c95b4ec14d3330a15d3da4917ab5 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sun, 25 Aug 2024 14:31:23 -0700 Subject: [PATCH 234/495] Clarify use cases for DuplicatedNoNeed (#1753) --- lib/EnzymeCore/src/EnzymeCore.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index ef1b34e56d..b5ecd348d6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -79,6 +79,13 @@ end Like [`Duplicated`](@ref), except also specifies that Enzyme may avoid computing the original result and only compute the derivative values. + +This should only be used if `x` is a write-only variable. Otherwise, if the differentiated +function stores values in `x` and reads them back in subsequent computations, using +`DuplicatedNoNeed` may result in incorrect derivatives. In particular, `DuplicatedNoNeed` +should not be used for preallocated workspace, even if the user might not care about its +final value, as marking a variable as NoNeed means that reads from the variable are now +undefined. 
""" struct DuplicatedNoNeed{T} <: Annotation{T} val::T From 32b7aa2b00b6f04d789ed920f0671474258f98a3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 25 Aug 2024 18:29:44 -0500 Subject: [PATCH 235/495] Fix multidim solve (#1754) --- src/internal_rules.jl | 8 +++++--- test/internal_rules.jl | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index c2a31c2e4d..0199d7077c 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -327,6 +327,8 @@ end return LinearAlgebra.qr(cache_A, ColumnNorm()) end +@inline onedimensionalize(::Type{T}) where T <: Array = Vector{eltype(T)} + # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy @@ -371,7 +373,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} @static if VERSION < v"1.8.0" UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, + LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, LinearAlgebra.LowerTriangular{eltype(AT), AT}, LinearAlgebra.UpperTriangular{eltype(AT), AT}, LinearAlgebra.LU{eltype(AT), AT}, @@ -379,11 +381,11 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} } else UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, + LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, LinearAlgebra.LowerTriangular{eltype(AT), AT}, LinearAlgebra.UpperTriangular{eltype(AT), AT}, LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT), AT, BT, Vector{Int}} + LinearAlgebra.QRPivoted{eltype(AT), AT, onedimensionalize(BT), Vector{Int}} } end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 835f195c4b..659d5dee98 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -133,6 +133,21 @@ end y = A \ b @test dA ≈ (-z * transpose(y)) + + # Ensure multi dim doesn't crash + function test2!(A) + A .= A \ [1.0 0;0.0 1.0] + return nothing + end + + A = rand(2,2) + dA = [1.0 0.0; 0.0 0.0] + + 
Enzyme.autodiff( + Enzyme.Reverse, + test2!, + Enzyme.Duplicated(A,dA), + ) end function tr_solv(A, B, uplo, trans, diag, idx) From c1e98c9ebb2921a78fd69ecab104b4a3cae1efb3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 25 Aug 2024 20:36:54 -0500 Subject: [PATCH 236/495] Type unstable custom rule tape (#1755) * Type unstable custom rule tape * fix * fix * Specialize tape type further --- src/compiler.jl | 46 ++++++++++++++++++++++++++++++++++ src/rules/customrules.jl | 53 +++++++++++++++++++++++++++++++++++----- test/rrules.jl | 33 +++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2fe1e38d69..66aeadefbb 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -862,6 +862,52 @@ function emit_jl!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value call!(B, FT, fn, [val]) end +function emit_getfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_pprjlvalue = LLVM.PointerType(T_prjlvalue) + T_int32 = LLVM.Int32Type() + + gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) + inv, _ = get_function!(mod, "jl_f_getfield", gen_FT) + + args = [val, fld] + + @static if VERSION < v"1.9.0-" + FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) + inv = bitcast!(B, inv, LLVM.PointerType(FT)) + res = call!(B, FT, inv, args) + LLVM.callconv!(res, 37) + else + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) + end + return res +end + + +function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + 
T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_size_t = convert(LLVM.LLVMType, Int) + + gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_size_t]) + inv, _ = get_function!(mod, "jl_get_nth_field_checked", gen_FT) + + args = [val, fld] + call!(B, gen_FT, inv, args) +end + function emit_jl_throw!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 3303dd6c59..749ec36cfd 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -696,6 +696,20 @@ end if (aug_RT <: EnzymeRules.AugmentedReturn || aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT isa UnionAll) && !(aug_RT isa Union) && !(aug_RT === Union{}) TapeT = EnzymeRules.tape_type(aug_RT) + elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturn) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name + if aug_RT.body.parameters[3] isa TypeVar + TapeT = aug_RT.body.parameters[3].ub + else + TapeT = Any + end + elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name + if aug_RT.body.parameters[3] isa TypeVar + TapeT = aug_RT.body.parameters[3].ub + else + TapeT = Any + end + else + TapeT = Any end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -778,6 +792,12 @@ end miRT = enzyme_custom_extract_mi(llvmf)[2] _, sret, returnRoots = get_return_info(miRT) + sret_union = is_sret_union(miRT) + + if sret_union + emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " had a union sret of type "*string(miRT)*" which is not currently supported") + return tapeV + end if !forward funcTy = rev_TT.parameters[isKWCall ? 4 : 2] @@ -960,16 +980,33 @@ end ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? 
RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT} end end + abstract = false if aug_RT != ST - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any} - emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) - return tapeV + abs = (EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, T} where T) + if aug_RT <: abs + abstract = true + else + ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any} + emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) + return tapeV + end + end + + resV = if abstract + StructTy = convert(LLVMType, EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Nothing}) + if StructTy != LLVM.VoidType() + load!(B, StructTy, bitcast!(B, res, LLVM.PointerType(StructTy, addrspace(value_type(res))))) + else + res + end + else + res end idx = 0 if needsPrimal @assert !isghostty(RealRt) - normalV = extract_value!(B, res, idx) + normalV = extract_value!(B, resV, idx) if get_return_info(RealRt)[2] !== nothing val = new_from_original(gutils, operands(orig)[1]) store!(B, normalV, val) @@ -982,7 +1019,7 @@ end if needsShadow if needsShadowJL @assert !isghostty(RealRt) - shadowV = extract_value!(B, res, idx) + shadowV = extract_value!(B, resV, idx) if get_return_info(RealRt)[2] !== nothing dval = invert_pointer(gutils, operands(orig)[1], B) @@ -1002,7 +1039,11 @@ end end end if needsTape - tapeV = extract_value!(B, res, idx).ref + tapeV = if abstract + emit_nthfield!(B, res, LLVM.ConstantInt(2)).ref + else + extract_value!(B, res, idx).ref + end idx+=1 end else diff --git a/test/rrules.jl b/test/rrules.jl index 3d330cf5fc..be4c4f1424 
100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -378,5 +378,38 @@ end @test dvals ≈ [0., 0., 46.7, 0.] end +unstabletape(x) = x^2 + +function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(unstabletape)}, ::Type{<:Active}, x::Active) + tape = if x.val < 3 + 400 + else + (x.val +7 ) * 10 + end + if needs_primal(config) + return AugmentedReturn{eltype(x), Nothing, typeof(tape)}(func.val(x.val), nothing, tape) + else + return AugmentedReturn{Nothing, Nothing, typeof(tape)}(nothing, nothing, tape) + end +end + +function reverse(config::ConfigWidth{1}, ::Const{typeof(unstabletape)}, dret, tape, x::Active{T}) where T + return (T(tape)::T,) +end + +unstabletapesq(x) = unstabletape(x)^2 + +@testset "Unstable Tape" begin + @test Enzyme.autodiff(Enzyme.Reverse, unstabletape, Active(2.0))[1][1] ≈ 400.0 + @test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletape, Active(2.0))[1][1] ≈ 400.0 + @test Enzyme.autodiff(Enzyme.Reverse, unstabletape, Active(5.0))[1][1] ≈ (5.0 + 7) * 10 + @test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletape, Active(5.0))[1][1] ≈ (5.0 + 7) * 10 + + @test Enzyme.autodiff(Enzyme.Reverse, unstabletapesq, Active(2.0))[1][1] ≈ (400.0) + @test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletapesq, Active(2.0))[1][1] ≈ (400.0) + @test Enzyme.autodiff(Enzyme.Reverse, unstabletapesq, Active(5.0))[1][1] ≈ ((5.0 + 7) * 10) + @test Enzyme.autodiff(Enzyme.ReverseWithPrimal, unstabletapesq, Active(5.0))[1][1] ≈ ((5.0 + 7) * 10) +end + include("mixedrrule.jl") end # ReverseRules From 44febc52cbc7b154900cc5afd846e658d483e931 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 26 Aug 2024 10:16:13 -0400 Subject: [PATCH 237/495] WIP: Add internal reverse-mode rules for ranges (#1656) * WIP: Add internal reverse-mode rules for ranges This is the second PR to fix https://github.com/EnzymeAD/Enzyme.jl/issues/274. 
It's separated as I think the forward mode one can just be merged no problem, and this one may take a little bit more time. The crux of why this one is hard is because of how Julia deals with malformed ranges. ``` Basically dret.val = 182.0:156.0:26.0, the 26.0 is not the true value. Same as julia> 10:1:1 10:1:9 ``` Because of that behavior, the reverse `dret` does not actually have the information as to what its final point is, and its length is "incorrect" as it's changed by the constructor. In order to "fix" the reverse, we'd want to swap the `step` to negative and then use the same start/stop, but that information is already lost so it cannot be fixed within the rule. You can see the commented out code that would do the fixing if the information is there, and without that we cannot get a correctly sized reversed range for the rule. But it's a bit puzzling to figure out how to remove that behavior. In Base Julia it seems to be done in the `function (:)(start::T, step::T, stop::T) where T<:IEEEFloat`, and as I showed in the issue, I can overload that function and the behavior goes away, but Enzyme's constructed range still has that truncation behavior, which means I missed a spot or something.
namespace ConfigWidth namespace namespace needs_primal namespace AugmentedReturn * Complete implementation * fix * fix --------- Co-authored-by: Billy Moses --- src/internal_rules.jl | 53 ++++++++++++++++++++++++++++++++++++++++++ test/internal_rules.jl | 11 +++++++++ 2 files changed, 64 insertions(+) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 0199d7077c..9dc1b83a69 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -861,6 +861,59 @@ function EnzymeRules.forward(func::Const{Colon}, end end + + +function EnzymeRules.augmented_primal(config, func::Const{Colon}, ::Type{<:Active}, + start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) + + if EnzymeRules.needs_primal(config) + primal = func.val(start.val, step.val, stop.val) + else + primal = nothing + end + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse(config, func::Const{Colon}, dret, tape::Nothing, + start::Annotation{T1}, step::Annotation{T2}, stop::Annotation{T3}) where {T1<:AbstractFloat, T2<:AbstractFloat, T3<:AbstractFloat} + + dstart = if start isa Const + nothing + elseif EnzymeRules.width(config) == 1 + T1(dret.val.ref.hi) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + T1(dret.val[i].ref.hi) + end + end + + dstep = if step isa Const + nothing + elseif EnzymeRules.width(config) == 1 + T2(dret.val.step.hi) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + T2(dret.val[i].step.hi) + end + end + + dstop = if stop isa Const + nothing + elseif EnzymeRules.width(config) == 1 + zero(T3) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + zero(T3) + end + end + + return (dstart, dstep, dstop) +end + + function EnzymeRules.forward( Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; diff --git a/test/internal_rules.jl 
b/test/internal_rules.jl index 659d5dee98..2b1e9bc621 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -668,6 +668,17 @@ end ((var"1"=75.0, var"2"=150.0),) @test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) == ((var"1"=0.0, var"2"=0.0),) + + @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((25.0,),) + @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),) + @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((75.0,),) + @test Enzyme.autodiff(Reverse, f4, Active, Active(0.12)) == ((0.0,),) + + # Batch active rule isnt setup + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(1.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((75.0,150.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),) end end # InternalRules From 8f3351dd2d355cd0c9819c84f70132b544888281 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 28 Aug 2024 12:07:00 -0500 Subject: [PATCH 238/495] Fix jlcall error handler (#1757) --- src/rules/llvmrules.jl | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 3d578f4dfc..df3b0ae181 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -104,7 +104,22 @@ include("parallelrules.jl") end end - emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in forward for "*string(orig)) + err = emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in forward for "*string(orig)) + + newo 
= new_from_original(gutils, orig) + + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + # Delete the primal code + if normal !== nothing + unsafe_store!(normalR, C_NULL) + else + ni = new_from_original(gutils, orig) + API.EnzymeGradientUtilsErase(gutils, ni) + end return false end @@ -145,7 +160,21 @@ end end end - emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in aug_forward for "*string(orig)) + err = emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in aug_forward for "*string(orig)) + newo = new_from_original(gutils, orig) + + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + # Delete the primal code + if normal !== nothing + unsafe_store!(normalR, C_NULL) + else + ni = new_from_original(gutils, orig) + API.EnzymeGradientUtilsErase(gutils, ni) + end return false end From 0900cd0cba5a116dddd3af54b6109181d77e0e2d Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Thu, 29 Aug 2024 11:29:11 -0400 Subject: [PATCH 239/495] fix LLVM 8 depwarns (#1758) * fix depwarns * revert on < Julia-1.9 * detect LLVM version from Julia 1.6 on * just use old LLVM defs --- src/compiler.jl | 2 +- src/compiler/orcv2.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 66aeadefbb..5e0e0a0b93 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -47,7 +47,7 @@ import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error include("compiler/utils.jl") -if LLVM.has_orc_v1() +if v"8" <= LLVM.version() < v"12" include("compiler/orcv1.jl") else include("compiler/orcv2.jl") diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 61971f47ed..d36b1ca1c1 
100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -9,7 +9,7 @@ import ..Compiler import ..Compiler: API, cpu_name, cpu_features @inline function use_ojit() - return LLVM.has_julia_ojit() && !Sys.iswindows() + return (VERSION >= v"1.10.0-DEV.1395") && !Sys.iswindows() end export get_trampoline From 6d24d3d4d7569f56776a7b77f843384fb9fb236e Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 31 Aug 2024 07:50:28 +0000 Subject: [PATCH 240/495] Improve RuntimeActivity error message --- src/compiler.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5e0e0a0b93..be939fdf16 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1064,6 +1064,13 @@ struct EnzymeRuntimeActivityError <: Base.Exception end function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) + println(io, "Constant memory is stored (or returned) to a differentiable variable.\n") + println(io, "As a result, Enzyme cannot provably ensure correctness and throws this error.\n") + println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage).\n") + println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!\n"); + println(io, "To work around this issue, either:\n"); + println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change, or\n") + println(io, " a) set Enzyme.API.runtimeActivity!(true) immediately after loading Enzyme (which maintains correctness, but may slightly reduce performance).\n") msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end @@ -2582,7 +2589,6 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err if illegalVal !== nothing println(io, " llvalue="*string(illegalVal)) end - println(io, "You may be using a constant variable as temporary storage for active memory 
(https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now") if bt !== nothing Base.show_backtrace(io, bt) end From 4a5dbba83e43f03fe745bbf58a940c86cabeba59 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 31 Aug 2024 05:28:38 -0500 Subject: [PATCH 241/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 50f357574b..4991d9074c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.33" +version = "0.12.34" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 821483812d7798f1b974ab9fdd1aad82998fc0ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 31 Aug 2024 15:37:29 -0500 Subject: [PATCH 242/495] Handle deepcopy of constant (#1765) --- src/internal_rules.jl | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 9dc1b83a69..bcb3b1c413 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -193,15 +193,21 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} x.val end - shadow = ntuple(Val(EnzymeRules.width(config))) do _ - Base.@_inline_meta - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) - ) - end - - if EnzymeRules.width(config) == 1 - shadow = shadow[1] + shadow = if EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + Enzyme.make_zero(source, + #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + ) + else + ntuple(Val(EnzymeRules.width(config))) do _ + Base.@_inline_meta + Enzyme.make_zero(source, + 
#=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + ) + end + end + else + nothing end return EnzymeRules.AugmentedReturn(primal, shadow, shadow) @@ -241,11 +247,13 @@ end end function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} - if EnzymeRules.width(config) == 1 - accumulate_into(x.dval, IdDict(), shadow) - else - for i in 1:EnzymeRules.width(config) - accumulate_into(x.dval[i], IdDict(), shadow[i]) + if EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + accumulate_into(x.dval, IdDict(), shadow) + else + for i in 1:EnzymeRules.width(config) + accumulate_into(x.dval[i], IdDict(), shadow[i]) + end end end From c054399ce9273bffaca230a4ee7d5c2cec799765 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 31 Aug 2024 16:37:38 -0400 Subject: [PATCH 243/495] fix: conditionally define typetree_inner for BFloat16 (#1759) --- ext/EnzymeBFloat16sExt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/EnzymeBFloat16sExt.jl b/ext/EnzymeBFloat16sExt.jl index 050825aef0..b08cdae278 100644 --- a/ext/EnzymeBFloat16sExt.jl +++ b/ext/EnzymeBFloat16sExt.jl @@ -3,8 +3,10 @@ module EnzymeBFloat16sExt using BFloat16s using Enzyme +if !(isdefined(Core, :BFloat16) && Core.BFloat16 === BFloat16) function Enzyme.typetree_inner(::Type{BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) return Enzyme.TypeTree(Enzyme.API.DT_BFloat16, -1, ctx) end +end end From 7037529751138a3a3a8bb6b7d908f5cdb3341363 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 31 Aug 2024 15:49:47 -0500 Subject: [PATCH 244/495] Bump LLVM.jl version (#1767) * Bump LLVM.jl version * Retain compat with old llvm's * fix --------- Co-authored-by: William Moses --- Project.toml | 4 ++-- src/compiler/validation.jl | 42 +++++++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 4991d9074c..551569a719 100644 --- a/Project.toml +++ 
b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.34" +version = "0.12.35" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -37,7 +37,7 @@ ChainRulesCore = "1" EnzymeCore = "0.7.8" Enzyme_jll = "0.0.145" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" -LLVM = "6.1, 7, 8" +LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 9c8bc3921c..b95d343bfb 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -196,8 +196,12 @@ function rewrite_ccalls!(mod::LLVM.Module) if changed prevname = LLVM.name(inst) LLVM.name!(inst, "") - newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), prevname) - for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] + if !isdefined(LLVM, :OperandBundleDef) + newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(operand_bundles(inst)), prevname) + else + newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), prevname) + end + for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] 
idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) @@ -213,13 +217,27 @@ function rewrite_ccalls!(mod::LLVM.Module) end continue end - newbundles = OperandBundleDef[] - for bunduse in operand_bundles(inst) - bunduse = LLVM.OperandBundleDef(bunduse) - if LLVM.tag_name(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end + if !isdefined(LLVM, :OperandBundleDef) + newbundles = OperandBundle[] + else + newbundles = OperandBundleDef[] + end + for bunduse in operand_bundles(inst) + if isdefined(LLVM, :OperandBundleDef) + bunduse = LLVM.OperandBundleDef(bunduse) + end + + if !isdefined(LLVM, :OperandBundleDef) + if LLVM.tag(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + else + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + end uservals = LLVM.Value[] subchanged = false for lval in LLVM.inputs(bunduse) @@ -246,7 +264,11 @@ function rewrite_ccalls!(mod::LLVM.Module) continue end changed = true - push!(newbundles, OperandBundleDef(LLVM.tag_name(bunduse), uservals)) + if !isdefined(LLVM, :OperandBundleDef) + push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) + else + push!(newbundles, OperandBundleDef(LLVM.tag_name(bunduse), uservals)) + end end changed = false if changed From 2cb733f1414c3471f648a7916a0f48b28c644c7d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 31 Aug 2024 23:50:34 -0400 Subject: [PATCH 245/495] Reduce newlines in newer error message --- src/compiler.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index be939fdf16..b17b478244 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1064,13 +1064,13 @@ struct EnzymeRuntimeActivityError <: Base.Exception end function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) 
- println(io, "Constant memory is stored (or returned) to a differentiable variable.\n") - println(io, "As a result, Enzyme cannot provably ensure correctness and throws this error.\n") - println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage).\n") - println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!\n"); - println(io, "To work around this issue, either:\n"); - println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change, or\n") - println(io, " a) set Enzyme.API.runtimeActivity!(true) immediately after loading Enzyme (which maintains correctness, but may slightly reduce performance).\n") + println(io, "Constant memory is stored (or returned) to a differentiable variable.") + println(io, "As a result, Enzyme cannot provably ensure correctness and throws this error.") + println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage).") + println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!"); + println(io, "To work around this issue, either:"); + println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change, or") + println(io, " a) set Enzyme.API.runtimeActivity!(true) immediately after loading Enzyme (which maintains correctness, but may slightly reduce performance).") msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end From 4b74341e399ec96d4b523b2350c6e3f67055a19e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Sep 2024 00:16:30 -0400 Subject: [PATCH 246/495] Slight message fixes --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index b17b478244..37f24dcb07 100644 --- 
a/src/compiler.jl +++ b/src/compiler.jl @@ -1069,8 +1069,8 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage).") println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!"); println(io, "To work around this issue, either:"); - println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change, or") - println(io, " a) set Enzyme.API.runtimeActivity!(true) immediately after loading Enzyme (which maintains correctness, but may slightly reduce performance).") + println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or") + println(io, " b) set Enzyme.API.runtimeActivity!(true) immediately after loading Enzyme (which maintains correctness, but may slightly reduce performance).") msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end From eb80bc9802513897df7ed64484dfb758e09afbd6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Sep 2024 11:19:50 -0500 Subject: [PATCH 247/495] Manual report exception for cuda (#1766) * Manual report exception for cuda * Update compiler.jl * Update compiler.jl * Update compiler.jl * Update compiler.jl --------- Co-authored-by: William Moses --- src/compiler.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 37f24dcb07..2338f60c83 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1926,7 +1926,21 @@ function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeExceptio end ct = if occursin("ptx", LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) - GPUCompiler.emit_exception!(B, string, orig) + exc = functions(mod)["gpu_report_exception"] + + name = globalstring_ptr!(B, string, "exception") + call!(B, 
LLVM.function_type(exc), exc, [name]) + + sig = GPUCompiler.Runtime.get(:signal_exception) + call!(B, sig) + + trap_ft = LLVM.FunctionType(LLVM.VoidType()) + trap = if haskey(functions(mod), "llvm.trap") + functions(mod)["llvm.trap"] + else + LLVM.Function(mod, "llvm.trap", trap_ft) + end + call!(B, trap_ft, trap) else err = emit_allocobj!(B, errty) err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) From 3588c9938f704558beb13f530f208de5741630eb Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Sep 2024 13:25:31 -0500 Subject: [PATCH 248/495] Loosen runtime activity errors (#1768) * Loosen runtime activity errors * Update compiler.jl * Update compiler.jl --------- Co-authored-by: William Moses --- src/compiler.jl | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2338f60c83..98f068dfac 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2460,7 +2460,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err if v != tmp changed = true end - push!(todo, tmp) + push!(cvals, tmp) end cur2 = if changed @@ -2474,7 +2474,10 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return cur2 end if isa(cur, LLVM.ConstantInt) - if LLVM.width(value_type(cur)) <= 8 + if LLVM.width(value_type(cur)) <= sizeof(Int)*8 + return make_batched(ncur, prevbb) + end + if LLVM.width(value_type(cur)) == sizeof(Int)*8 && abs(convert(Int, cur)) < 10000 return make_batched(ncur, prevbb) end # if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe @@ -2484,6 +2487,35 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end end + if isa(cur, LLVM.SelectInst) + lhs = make_replacement(operands(cur)[2], prevbb) + if illegal + return ncur + end + rhs = make_replacement(operands(cur)[3], prevbb) + if illegal + return ncur + end + if lhs == 
operands(cur)[2] && rhs == operands(cur)[3] + return make_batched(ncur, prevbb) + end + if width == 1 + nv = select!(prevbb, new_from_original(gutils, operands(cur)[1]), lhs, rhs) + push!(created, nv) + seen[cur] = nv + return nv + else + shadowres = LLVM.UndefValue(value_type(lhs)) + for idx in 1:width + shadowres = insert_value!(prevbb, shadowres, select!(new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx), extract_value!(prevbb, rhs, idx)), idx) + if isa(shadowres, LLVM.Instruction) + push!(created, shadowres) + end + end + return shadowres + end + end + if isa(cur, LLVM.InsertValueInst) lhs = make_replacement(operands(cur)[1], prevbb) if illegal From c6b3f61b01e5266ac03d5649587f8e9052ab3f4a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Sep 2024 13:25:49 -0500 Subject: [PATCH 249/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 551569a719..82196219d2 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.145" +Enzyme_jll = "0.0.146" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 67c529cbc7b9031c99791491f77f02c6586f1f39 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Sep 2024 18:50:38 -0500 Subject: [PATCH 250/495] Nancheck for GPU (#1770) --- src/compiler.jl | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 98f068dfac..d44d16362f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1925,26 +1925,29 @@ function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeExceptio string*=sprint(io->Base.show_backtrace(io, bt)) end + if !isa(string, LLVM.Value) + string = globalstring_ptr!(B, string, "enz_exception") + end + ct = if occursin("ptx", 
LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) exc = functions(mod)["gpu_report_exception"] - name = globalstring_ptr!(B, string, "exception") - call!(B, LLVM.function_type(exc), exc, [name]) + call!(B, LLVM.function_type(exc), exc, [string]) - sig = GPUCompiler.Runtime.get(:signal_exception) - call!(B, sig) + sig = GPUCompiler.Runtime.get(:signal_exception) + call!(B, sig) - trap_ft = LLVM.FunctionType(LLVM.VoidType()) - trap = if haskey(functions(mod), "llvm.trap") - functions(mod)["llvm.trap"] - else - LLVM.Function(mod, "llvm.trap", trap_ft) - end - call!(B, trap_ft, trap) + trap_ft = LLVM.FunctionType(LLVM.VoidType()) + trap = if haskey(functions(mod), "llvm.trap") + functions(mod)["llvm.trap"] + else + LLVM.Function(mod, "llvm.trap", trap_ft) + end + call!(B, trap_ft, trap) else err = emit_allocobj!(B, errty) err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) - store!(B, globalstring_ptr!(B, string), err2) + store!(B, string, err2) emit_jl_throw!(B, addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12))) end @@ -2182,14 +2185,11 @@ function julia_sanitize(orig::LLVM.API.LLVMValueRef, val::LLVM.API.LLVMValueRef, position!(builder, good) ret!(builder) - # ret!(builder, inp) + position!(builder, bad) - err = emit_allocobj!(builder, EnzymeRuntimeException) - err2 = bitcast!(builder, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) - store!(builder, globalstring_ptr!(builder, string), err2) - emit_jl_throw!(builder, addrspacecast!(builder, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12))) - unreachable!(builder) + emit_error(builder, nothing, sval, EnzymeNoDerivativeError) + unreachable!(builder) dispose(builder) end end From 0ed957866c915565924dcf1b81fbc22216777038 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 2 Sep 2024 15:18:37 -0500 Subject: [PATCH 251/495] Fix GPU errors (#1774) * Fix GPU errors * fix * ptrtoint * Update compiler.jl * Update compiler.jl --- 
src/compiler.jl | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d44d16362f..8aa3114870 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1915,28 +1915,38 @@ function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeExceptio fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - # 1. get the error function - if orig !== nothing - bt = GPUCompiler.backtrace(orig) - function printBT(io) - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) - end - string*=sprint(io->Base.show_backtrace(io, bt)) - end - if !isa(string, LLVM.Value) string = globalstring_ptr!(B, string, "enz_exception") end ct = if occursin("ptx", LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) - exc = functions(mod)["gpu_report_exception"] + + vt = LLVM.VoidType() + ptr = convert(LLVMType, Ptr{Cvoid}) + + exc, _ = get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) + + string = ptrtoint!(B, string, ptr) call!(B, LLVM.function_type(exc), exc, [string]) - sig = GPUCompiler.Runtime.get(:signal_exception) - call!(B, sig) + framefn, ft = get_function!(mod, "gpu_report_exception_frame", LLVM.FunctionType(vt, [LLVM.Int32Type(), ptr, ptr, LLVM.Int32Type()])) + if orig !== nothing + bt = GPUCompiler.backtrace(orig) + for (i,frame) in enumerate(bt) + idx = ConstantInt(parameters(ft)[1], i) + func = globalstring_ptr!(B, String(frame.func), "di_func") + func = ptrtoint!(B, func, ptr) + file = globalstring_ptr!(B, String(frame.file), "di_file") + file = ptrtoint!(B, file, ptr) + line = ConstantInt(parameters(ft)[4], frame.line) + call!(B, ft, framefn, [idx, func, file, line]) + end + end + + sigfn, sigft = get_function!(mod, "gpu_signal_exception", LLVM.FunctionType(vt, LLVM.LLVMType[])) + call!(B, sigft, sigfn) trap_ft = LLVM.FunctionType(LLVM.VoidType()) trap = if haskey(functions(mod), "llvm.trap") functions(mod)["llvm.trap"] From 
4ddaccd470a34ef354d6e25e354fdc61a1504b19 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 2 Sep 2024 15:21:37 -0500 Subject: [PATCH 252/495] Fix type error (#1775) --- src/compiler.jl | 4 ++-- src/typetree.jl | 4 ++-- test/runtests.jl | 31 +++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8aa3114870..5d2d29563f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6247,13 +6247,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; Ptr{source_typ} end else - codegen_typ + source_typ end if isa(inst, LLVM.CallInst) LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", string(typetree(typ, ctx, dl, seen)))) else - metadata(inst)["enzyme_type"] = to_md(typetree(arg.typ, ctx, dl, seen), ctx) + metadata(inst)["enzyme_type"] = to_md(typetree(typ, ctx, dl, seen), ctx) end elseif codegen_typ == T_prjlvalue if isa(inst, LLVM.CallInst) diff --git a/src/typetree.jl b/src/typetree.jl index ae4329c172..73b296b95b 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -77,7 +77,7 @@ Construct a Enzyme typetree from a Julia type. When using a memoized lookup by providing `seen` across multiple calls to typtree the user must call `copy` on the returned value before mutating it. 
""" -function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) +function typetree(@nospecialize(T::Type), ctx, dl, seen=TypeTreeTable()) if haskey(seen, T) tree = seen[T] if tree === nothing @@ -205,7 +205,7 @@ else ismutabletype(T) = isa(T, DataType) && T.mutable end -function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) +function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) return TypeTree() end diff --git a/test/runtests.jl b/test/runtests.jl index 477f646c74..a869443a09 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -814,6 +814,37 @@ end @test hv ≈ [-1.0] end +@testset "Nested Type Error" begin + nested_f(x) = sum(tanh, x) + + function nested_df!(dx, x) + make_zero!(dx) + autodiff_deferred(Reverse, nested_f, Active, Duplicated(x, dx)) + return nothing + end + + function nested_hvp!(hv, v, x) + make_zero!(hv) + autodiff(Forward, nested_df!, Const, Duplicated(make_zero(x), hv), Duplicated(x, v)) + return nothing + end + + x = [0.5] + + # primal: sanity check + @test nested_f(x) ≈ sum(tanh, x) + + # gradient: works + dx = make_zero(x) + nested_df!(dx, x) + + @test dx ≈ (sech.(x).^2) + + v = first(onehot(x)) + hv = make_zero(v) + nested_hvp!(hv, v, x) +end + @testset "Array tests" begin function arsum(f::Array{T}) where T From 4aa1d5b315b608ec88190e9a5555db8fb99a5ede Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 2 Sep 2024 20:22:11 +0000 Subject: [PATCH 253/495] Fix gc loaded usage --- src/compiler.jl | 7 +++++++ src/rules/llvmrules.jl | 6 ------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5d2d29563f..8632469655 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3777,6 +3777,13 @@ function annotate!(mod, mode) end end + for fname in ("julia.gc_loaded",) + if haskey(fns, fname) + fn = fns[fname] + push!(function_attributes(fn), 
LLVM.StringAttribute("enzyme_shouldrecompute")) + end + end + for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw", "julia.pointer_from_objref", "ijl_array_grow_end", "jl_array_grow_end", "ijl_array_del_end", "jl_array_del_end", "ijl_array_grow_beg", "jl_array_grow_beg", "ijl_array_del_beg", "jl_array_del_beg", diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index df3b0ae181..31b2454a1b 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1368,12 +1368,6 @@ end @revfunc(jlcall2_rev), @fwdfunc(jlcall2_fwd), ) - register_handler!( - ("julia.gc_loaded",), - @augfunc(gcloaded_augfwd), - @revfunc(gcloaded_rev), - @fwdfunc(gcloaded_fwd), - ) register_handler!( ("jl_apply_generic", "ijl_apply_generic"), @augfunc(generic_augfwd), From c23a8250c4a50307b68f4bdf89e4d8f9e545543d Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 2 Sep 2024 17:04:05 -0500 Subject: [PATCH 254/495] Fix readonly on julia 1.11 (#1777) --- src/compiler.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8632469655..58ad563224 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6111,8 +6111,12 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; name = string(name) name = T == Float32 ? 
name*"f" : name - handleCustom(llvmfn, name, [EnumAttribute("readnone", 0), - StringAttribute("enzyme_shouldrecompute")]) + attrs = if LLVM.version().major <= 15 + [LLVM.EnumAttribute("readnone"), StringAttribute("enzyme_shouldrecompute")] + else + [EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute")] + end + handleCustom(llvmfn, name, attrs) end @assert actualRetType !== nothing From a295b52ca366ef50372e9c08607c6fcfe1cf7789 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 2 Sep 2024 19:23:55 -0500 Subject: [PATCH 255/495] Fix julia select mixed activity (#1776) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 58ad563224..041203c37b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2517,7 +2517,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err else shadowres = LLVM.UndefValue(value_type(lhs)) for idx in 1:width - shadowres = insert_value!(prevbb, shadowres, select!(new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx), extract_value!(prevbb, rhs, idx)), idx) + shadowres = insert_value!(prevbb, shadowres, select!(prevbb, new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx), extract_value!(prevbb, rhs, idx)), idx) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end From 8a81e1bbab12b783286db5bf5dd1bbbf9ecfb226 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 3 Sep 2024 02:28:24 +0200 Subject: [PATCH 256/495] Implement `set_err_if_func_written` for `ReverseModeSplit` (#1722) --- lib/EnzymeCore/src/EnzymeCore.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index b5ecd348d6..0e67e4e3c0 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -251,6 +251,10 @@ const ReverseSplitNoPrimal = 
ReverseModeSplit{false, true, 0, true,DefaultABI, f const ReverseSplitWithPrimal = ReverseModeSplit{true, true, 0, true,DefaultABI, false}() @inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() @inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() + +@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, true}() +@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}() + """ struct Forward <: Mode From c4c6e48e75a586925ae9f35b095bc7e6e5ce3278 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 2 Sep 2024 20:01:57 -0500 Subject: [PATCH 257/495] Don't consider size of memset to be writing --- src/compiler.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 041203c37b..65cd00a10d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6369,6 +6369,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if isa(user, LLVM.CallInst) called = LLVM.called_operand(user) if isa(called, LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(called) + if intr == LLVM.Intrinsic("llvm.memset").id + if cur != operands(user)[1] + continue + 
end + end + nm = LLVM.name(called) if nm == "ijl_alloc_array_1d" || nm == "jl_alloc_array_1d" || nm == "ijl_alloc_array_2d" || nm == "jl_alloc_array_2d" || From 545f3d138c81d50045d13ce862da013859bc56e3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 2 Sep 2024 20:21:25 -0500 Subject: [PATCH 258/495] Fix select index offset (#1778) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 65cd00a10d..020a2d0b59 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2517,7 +2517,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err else shadowres = LLVM.UndefValue(value_type(lhs)) for idx in 1:width - shadowres = insert_value!(prevbb, shadowres, select!(prevbb, new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx), extract_value!(prevbb, rhs, idx)), idx) + shadowres = insert_value!(prevbb, shadowres, select!(prevbb, new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx-1), extract_value!(prevbb, rhs, idx-1)), idx-1) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end From 9db8a4b4b7f1da5ce1cfb3a33cbe492690f5a6d2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 07:16:27 -0500 Subject: [PATCH 259/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 82196219d2..8229dfc170 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.35" +version = "0.12.36" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.146" +Enzyme_jll = "0.0.147" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" 
From 54fdd094e4615c1085c1e46ea40df85baae4250c Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 10:52:34 -0500 Subject: [PATCH 260/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8229dfc170..02636694ba 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.147" +Enzyme_jll = "0.0.146" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From f520dd4912f3becc2c5286e7ebd9cc3a9fd2f869 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 13:07:49 -0500 Subject: [PATCH 261/495] Add forward mode svec_ref (#1782) --- src/rules/llvmrules.jl | 2 +- src/rules/typeunstablerules.jl | 68 +++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 31b2454a1b..9c4feb126c 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -149,7 +149,7 @@ end if in(name, ("ijl_f__apply_iterate", "jl_f__apply_iterate")) return common_apply_iterate_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if in(name, ("ijl_f__svec_rev", "jl_f__svec_ref")) + if in(name, ("ijl_f__svec_ref", "jl_f__svec_ref")) return common_f_svec_ref_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 42fcbe18cf..5372a67726 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1330,27 +1330,69 @@ end common_setfield_rev(1, B, orig, gutils, tape) end - +function error_if_differentiable(::Type{T}) where T + seen = () + areg = active_reg_inner(T, seen, nothing) + if areg != AnyState + throw(AssertionError("Found unhandled differentiable variable in 
jl_f_svec_ref $T")) + end + nothing +end function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - emit_error(B, orig, "Enzyme: unhandled forward for jl_f__svec_ref") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - if shadowR != C_NULL && normal !== nothing - unsafe_store!(shadowR, normal.ref) + + width = get_width(gutils) + + origmi, origh, origkey = operands(orig)[offset:end-1] + + shadowh = invert_pointer(gutils, origh, B) + + newvals = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal] + + if offset != 1 + pushfirst!(newvals, API.VT_Primal) end - return false -end + + mi = new_from_original(gutils, origmi) -function error_if_differentiable(::Type{T}) where T - seen = () - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) - if areg != AnyState - throw(AssertionError("Found unhandled differentiable variable in jl_f_svec_ref $T")) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + + shadowres = if width == 1 + newops = LLVM.Value[mi, shadowh, new_from_original(gutils, origkey)] + if offset != 1 + pushfirst!(newops, operands(orig)[1]) + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + + if is_constant_value(gutils, origh) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + end + cal + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for j in 1:width + newops = LLVM.Value[mi, extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey)] + if offset != 1 + pushfirst!(newops, operands(orig)[1]) + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + if is_constant_value(gutils, origh) + 
emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + end + shadow = insert_value!(B, shadow, cal, j-1) + end + shadow end - nothing + + unsafe_store!(shadowR, shadowres.ref) + + return false end function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) From d883c6ac18204bacf3edd275be51f160b5af18f0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 17:41:21 -0500 Subject: [PATCH 262/495] Do write barrier after mixed duplicated allocation upgrade (#1785) * Do write barrier after mixed duplicated allocation upgrade * Add dump post wrap option --- src/compiler.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 020a2d0b59..6d3161b25f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4002,6 +4002,7 @@ include("rules/activityrules.jl") @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED const DumpPreEnzyme = Ref(false) +const DumpPostWrap = Ref(false) function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) if DumpPreEnzyme[] @@ -4189,6 +4190,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else @assert "Unhandled derivative mode", mode end + if DumpPostWrap[] + API.EnzymeDumpModuleRef(mod.ref) + end API.EnzymeLogicErasePreprocessedFunctions(logic) adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) augmented_primalfname = augmented_primalf == nothing ? 
nothing : LLVM.name(augmented_primalf) @@ -4495,9 +4499,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, convty = convert(LLVMType, T′; allow_boxed=true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) - al = emit_allocobj!(builder, Base.RefValue{T′}) + al0 = al = emit_allocobj!(builder, Base.RefValue{T′}) al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, params[i], al) + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) push!(realparms, al) else From d7cb43bb53011c768f1b7f68859dce7399cdaef2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 17:41:41 -0500 Subject: [PATCH 263/495] Handle batch closures (#1784) --- src/compiler.jl | 4 +++- test/runtests.jl | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6d3161b25f..a9bcaddf53 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6872,8 +6872,10 @@ end argexpr = :(fn.dval) if isboxed push!(types, Any) - else + elseif width == 1 push!(types, F) + else + push!(types, NTuple{width, F}) end push!(ccexprs, argexpr) end diff --git a/test/runtests.jl b/test/runtests.jl index a869443a09..250fa20c97 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -313,6 +313,23 @@ end BatchDuplicated(ones(3), (ones(3), ones(3)))) end +struct MyClosure{A} + a::A +end + +function (mc::MyClosure)(x) + # computes x^2 using internal storage + mc.a[1] = x + return mc.a[1]^2 +end + +@testset "Batch Closure" begin + g = MyClosure([0.0]) + g_and_dgs = BatchDuplicated(g, (make_zero(g), make_zero(g))) + x_and_dxs = BatchDuplicated(3.0, (5.0, 7.0)) + autodiff(Forward, g_and_dgs, BatchDuplicated, x_and_dxs) # error +end + # @testset "Split Tape" begin # f(x) = x[1] * x[1] From 
75a2f4c9fd072e4e6a11c52058c757aee9620a95 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 22:37:23 -0500 Subject: [PATCH 264/495] Improve zero-set location error (#1788) --- src/rules/llvmrules.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 9c4feb126c..fb93016063 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -327,7 +327,12 @@ end elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) len = get_array_len(B, shadowin) length = LLVM.mul!(B, len, elSize) - GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type" + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io,"\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, shadowres), LLVM.ConstantInt(i8, 0, false), length, algn) end if API.runtimeActivity() @@ -345,7 +350,12 @@ end elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) len = get_array_len(B, ev) length = LLVM.mul!(B, len, elSize) - GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type" + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io,"\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn) end if API.runtimeActivity() From cdc790d0d25d2938fd520be2645dca6bbf33d711 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 22:54:24 -0500 Subject: [PATCH 265/495] CompatHelper: bump compat for Enzyme_jll to 0.0.148, (keep existing compat) (#1787) Co-authored-by: CompatHelper Julia --- Project.toml | 30 
+++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 02636694ba..b84ba12996 100644 --- a/Project.toml +++ b/Project.toml @@ -16,26 +16,12 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[weakdeps] -BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[extensions] -EnzymeBFloat16sExt = "BFloat16s" -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeLogExpFunctionsExt = "LogExpFunctions" -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [compat] BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.146" +Enzyme_jll = "0.0.146, 0.0.148" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" @@ -45,9 +31,23 @@ SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.6" +[extensions] +EnzymeBFloat16sExt = "BFloat16s" +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeLogExpFunctionsExt = "LogExpFunctions" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [extras] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = 
"90137ffa-7385-5640-81b9-e52037218182" From 8dde0343d7b4fdc92ecf4fe4fafd7e7df7cf1427 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 23:38:08 -0500 Subject: [PATCH 266/495] Add names to object emission (#1789) --- src/compiler.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index a9bcaddf53..ad2d943c21 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -748,7 +748,7 @@ declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) end end -function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool) +function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool, name::String="") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -792,12 +792,12 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: alloc_obj, alty = declare_allocobj!(mod) @static if VERSION < v"1.8.0" - return call!(B, alty, alloc_obj, [ptls, Size, tag]) + return call!(B, alty, alloc_obj, [ptls, Size, tag], name) else - return call!(B, alty, alloc_obj, [ct, Size, tag]) + return call!(B, alty, alloc_obj, [ct, Size, tag], name) end end -function emit_allocobj!(B, T::DataType) +function emit_allocobj!(B, T::DataType, name::String="") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -811,7 +811,7 @@ function emit_allocobj!(B, T::DataType) T_size_t = convert(LLVM.LLVMType, UInt) Size = LLVM.ConstantInt(T_size_t, sizeof(T)) - emit_allocobj!(B, tag, Size, #=needs_workaround=#false) + emit_allocobj!(B, tag, Size, #=needs_workaround=#false, name) end declare_pointerfromobjref!(mod) = get_function!(mod, "julia.pointer_from_objref") do T_jlvalue = LLVM.StructType(LLVMType[]) @@ -4499,7 +4499,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, convty = convert(LLVMType, T′; 
allow_boxed=true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) - al0 = al = emit_allocobj!(builder, Base.RefValue{T′}) + al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter") al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, params[i], al) emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) @@ -4649,7 +4649,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) for idx in 1:width pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1) - al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}) + al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}, "batchmixedret") llty = value_type(pv) al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, pv, al) @@ -5236,9 +5236,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if arg.arg_i in loweredArgs push!(nops, load!(builder, convert(LLVMType, arg.typ), parm)) elseif arg.arg_i in raisedArgs - obj = emit_allocobj!(builder, arg.typ) + obj = emit_allocobj!(builder, arg.typ, "raisedArg") bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj)))) store!(builder, parm, bc) + emit_writebarrier!(builder, get_julia_inner_types(builder, obj, parm)) addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived)) push!(nops, addr) else @@ -5374,7 +5375,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function ret!(builder, fill_val) else nobj = if sretPtr !== nothing - obj = emit_allocobj!(builder, jlrettype) + obj = emit_allocobj!(builder, jlrettype, "boxunion") llty = convert(LLVMType, jlrettype) ld = load!(builder, llty, bitcast!(builder, sretPtr, LLVM.PointerType(llty, 
addrspace(value_type(sretPtr))))) store!(builder, ld, bitcast!(builder, obj, LLVM.PointerType(llty, addrspace(value_type(obj))))) From a5ec75f2a9000d90a107a4b37f11d50f8f13671e Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Sep 2024 23:01:19 -0500 Subject: [PATCH 267/495] Consider constant fp in runtime activity (#1797) * Consider constant fp in runtime activity * fix --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index ad2d943c21..94a012bd2c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2459,6 +2459,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return make_batched(ncur, prevbb) end end + if isa(cur, LLVM.ConstantFP) + return make_batched(ConstantFP(value_type(cur), 0), prevbb) + end if isa(cur, LLVM.ConstantDataSequential) cvals = LLVM.Value[] changed = false From 0307b78de83cff587be7f098afc016db1f5a6451 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 6 Sep 2024 00:02:46 -0400 Subject: [PATCH 268/495] Suggest workaround in error for overwritten active by ref (#1791) --- src/rules/customrules.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 749ec36cfd..0f77d37a9d 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -136,7 +136,12 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) if value_type(val) != eltype(value_type(ptr)) if overwritten[end] - emit_error(B, orig, "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + emit_error( + B, + orig, + "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). 
" + * "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.", + ) end if arty == eltype(value_type(val)) val = load!(B, arty, val) From b91fb0798532912bd7d666ebf5666a4769267e08 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Sep 2024 23:08:46 -0500 Subject: [PATCH 269/495] Fix custom active reverse mode check (#1798) --- src/rules/customrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 0f77d37a9d..9628623987 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1052,7 +1052,7 @@ end idx+=1 end else - Tys = (A <: Active ? eltype(A) : Nothing for A in activity[2+isKWCall:end]) + Tys = (A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width), eltype(A)}) : Nothing for A in activity[2+isKWCall:end]) ST = Tuple{Tys...} if rev_RT != ST emit_error(B, orig, "Enzyme: Reverse pass custom rule " * string(rev_TT) * " return type mismatch, expected "*string(ST)*" found "* string(rev_RT)) From 754937bacb860d6235c9d3ea86104649b838c5ff Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Sep 2024 10:50:25 -0500 Subject: [PATCH 270/495] Look for more writebarrier opportunities (#1800) * Look for more writebarrier opportunities * Update compiler.jl --- src/compiler.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 94a012bd2c..c5684c2e7e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4505,7 +4505,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter") al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, params[i], al) - emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) al = 
addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) push!(realparms, al) else @@ -5382,6 +5382,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function llty = convert(LLVMType, jlrettype) ld = load!(builder, llty, bitcast!(builder, sretPtr, LLVM.PointerType(llty, addrspace(value_type(sretPtr))))) store!(builder, ld, bitcast!(builder, obj, LLVM.PointerType(llty, addrspace(value_type(obj))))) + emit_writebarrier!(builder, get_julia_inner_types(builder, obj, ld)) # memcpy!(builder, bitcast!(builder, obj, LLVM.PointerType(T_int8, addrspace(value_type(obj)))), 0, bitcast!(builder, sretPtr, LLVM.PointerType(T_int8)), 0, LLVM.ConstantInt(T_int64, sizeof(jlrettype))) obj else From 14851efd29d85a5c0775ff14a409aadb3f4cf4f2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 12 Sep 2024 09:04:02 -0400 Subject: [PATCH 271/495] Restrict version to 1.10+ (#1809) * Restrict version to 1.10+ * fix * fixup * Update CI.yml * Update Project.toml * Update Project.toml --- .github/workflows/CI.yml | 45 +--- Project.toml | 6 +- lib/EnzymeTestUtils/Project.toml | 2 +- src/Enzyme.jl | 56 ++--- src/compiler.jl | 348 +++++++------------------------ src/compiler/interpreter.jl | 100 +-------- src/compiler/orcv1.jl | 181 ---------------- src/compiler/orcv2.jl | 9 +- src/compiler/reflection.jl | 29 +-- src/compiler/utils.jl | 67 ------ src/compiler/validation.jl | 67 ++---- src/internal_rules.jl | 16 +- src/rules/jitrules.jl | 14 -- src/rules/parallelrules.jl | 28 +-- src/typetree.jl | 12 +- src/utils.jl | 106 +--------- test/DiffTests.jl | 18 -- test/applyiter.jl | 2 - test/internal_rules.jl | 6 +- test/mixed.jl | 4 - test/rrules.jl | 3 - test/runtests.jl | 169 ++++++--------- 22 files changed, 184 insertions(+), 1104 deletions(-) delete mode 100644 src/compiler/orcv1.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5093cf3a5a..60d713c529 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,10 +21,6 @@ 
jobs: fail-fast: false matrix: version: - - '1.6' - - '1.7' - - '1.8' - - '1.9' - '1.10' - ~1.11.0-0 - 'nightly' @@ -42,46 +38,11 @@ jobs: arch: x64 libEnzyme: local include: - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.6' - assertions: false - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.7' - assertions: false - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.8' - assertions: false - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.9' - assertions: false - os: ubuntu-20.04 arch: x86 libEnzyme: packaged version: '1.10' assertions: false - - os: ubuntu-20.04 - arch: x64 - libEnzyme: packaged - version: '1.7' - assertions: true - - os: ubuntu-20.04 - arch: x64 - libEnzyme: packaged - version: '1.8' - assertions: true - - os: ubuntu-20.04 - arch: x64 - libEnzyme: packaged - version: '1.9' - assertions: true - os: ubuntu-20.04 arch: x64 libEnzyme: packaged @@ -125,7 +86,8 @@ jobs: shell: julia --color=yes --project=. 
{0} run: | using Pkg - Pkg.develop(path="lib/EnzymeCore") + Pkg.develop([PackageSpec(; path) for path in ("lib/EnzymeCore", "lib/EnzymeTestUtils")]) + Pkg.instantiate() env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - name: Build libEnzyme @@ -172,9 +134,6 @@ jobs: fail-fast: false matrix: version: - - '1.7' - - '1.8' - - '1.9' - '1.10' - ~1.11.0-0 - 'nightly' diff --git a/Project.toml b/Project.toml index b84ba12996..15890547e1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.36" +version = "0.13.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -23,13 +23,13 @@ ChainRulesCore = "1" EnzymeCore = "0.7.8" Enzyme_jll = "0.0.146, 0.0.148" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" -LLVM = "6.1, 7, 8, 9" +LLVM = "6.1, 7, 8, =9.0" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" SpecialFunctions = "1, 2" StaticArrays = "1" -julia = "1.6" +julia = "1.10" [extensions] EnzymeBFloat16sExt = "BFloat16s" diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 38b783facc..80dd2ede75 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -13,7 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ConstructionBase = "1.4.1" -Enzyme = "0.11, 0.12" +Enzyme = "0.11, 0.12, 0.13" EnzymeCore = "0.5, 0.6, 0.7" FiniteDifferences = "0.12.12" MetaTesting = "0.1" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index b7d86b4705..450d96ffb0 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -249,11 +249,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end rt = if A isa UnionAll - @static if VERSION >= v"1.8.0" - Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) - else - Core.Compiler.return_type(f.val, tt) - end + Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) else eltype(A) end 
@@ -339,7 +335,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. """ @inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode && VERSION >= v"1.8.0" + rt = if mode isa ReverseMode Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) else Core.Compiler.return_type(f.val, tt) @@ -556,7 +552,7 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur @inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs} tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode && VERSION >= v"1.8.0" + rt = if mode isa ReverseMode Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) else Core.Compiler.return_type(f.val, tt) @@ -903,11 +899,7 @@ result, ∂v, ∂A end rt = if RT isa UnionAll - @static if VERSION < v"1.8-" - throw(MethodError(autodiff_deferred_thunk, (mode, tt, fa, a2, args...))) - else - RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})} - end + RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})} else @assert RT isa DataType RT @@ -1243,13 +1235,9 @@ of shape `size(input)` of values of the output type. 
inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = @static if VERSION >= v"1.9" - Base.stack(cols) - else - reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) - end + st = Base.stack(cols) - st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st3 = if length(inshape) <= 1 st else reshape(st, (outshape..., inshape...)) @@ -1279,13 +1267,9 @@ end inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = @static if VERSION >= v"1.9" - Base.stack(cols) - else - reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) - end + st = Base.stack(cols) - st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st3 = if length(inshape) <= 1 st else reshape(st, (outshape..., inshape...)) @@ -1311,13 +1295,9 @@ end inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = @static if VERSION >= v"1.9" - Base.stack(cols) - else - reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) - end + st = Base.stack(cols) - st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st3 = if length(inshape) <= 1 st else reshape(st, (outshape..., inshape...)) @@ -1416,13 +1396,9 @@ of shape `size(output)` of values of the input type. 
if x isa AbstractArray inshape = size(x) - st = @static if VERSION >= v"1.9" - Base.stack(rows) - else - reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) - end + st = Base.stack(rows) - st2 = if length(outshape) == 1 || VERSION < v"1.9" + st2 = if length(outshape) == 1 st else reshape(st, (inshape..., outshape...)) @@ -1469,13 +1445,9 @@ end outshape = tmp[1][2] if x isa AbstractArray inshape = size(x) - st = @static if VERSION >= v"1.9" - Base.stack(rows) - else - reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) - end + st = Base.stack(rows) - st2 = if length(outshape) == 1 || VERSION < v"1.9" + st2 = if length(outshape) == 1 st else reshape(st, (inshape..., outshape...)) diff --git a/src/compiler.jl b/src/compiler.jl index c5684c2e7e..f82ae6c135 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -31,27 +31,14 @@ function cpu_name() end function cpu_features() - if VERSION >= v"1.10.0-beta1" - return ccall(:jl_get_cpu_features, String, ()) - end - - @static if Sys.ARCH == :x86_64 || - Sys.ARCH == :x86 - return "+mmx,+sse,+sse2,+fxsr,+cx8" # mandated by Julia - else - return "" - end + return ccall(:jl_get_cpu_features, String, ()) end import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error include("compiler/utils.jl") -if v"8" <= LLVM.version() < v"12" - include("compiler/orcv1.jl") -else - include("compiler/orcv2.jl") -end +include("compiler/orcv2.jl") include("gradientutils.jl") @@ -97,11 +84,9 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( typeof(Base.FastMath.cosh_fast) => (:cosh, 1, nothing), typeof(Base.tanh) => (:tanh, 1, nothing), typeof(Base.ldexp) => (:ldexp, 2, nothing), - typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing) + typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing), + typeof(Base.fma_emulated) => (:fma, 3, nothing) ) -@static if VERSION >= v"1.8.0" - known_ops[typeof(Base.fma_emulated)] = (:fma, 3, nothing) -end @inline function 
find_math_method(@nospecialize(func), sparam_vals) if func ∈ keys(known_ops) name, arity, toinject = known_ops[func] @@ -425,39 +410,7 @@ end @inline is_arrayorvararg_ty(::Type{IdDict{K, V} where K}) where {V} = true @inline function datatype_fieldcount(t::Type{T}) where T - @static if VERSION < v"1.10.0" - NT = @static if VERSION < v"1.9.0" - Base.NamedTuple_typename - else - Base._NAMEDTUPLE_NAME - end - if t.name === NT - names, types = t.parameters[1], t.parameters[2] - if names isa Tuple - return length(names) - end - if types isa DataType && types <: Tuple - return datatype_fieldcount(types) - end - return nothing - else - @static if VERSION < v"1.7.0" - if t.abstract || (t.name === Tuple.name && Base.isvatuple(t)) - return nothing - end - else - if isabstracttype(t) || (t.name === Tuple.name && Base.isvatuple(t)) - return nothing - end - end - end - if isdefined(t, :types) - return length(t.types) - end - return length(t.name.names) - else - return Base.datatype_fieldcount(t) - end + return Base.datatype_fieldcount(t) end @inline function staticInTup(::Val{T}, tup::NTuple{N, Val}) where {T, N} @@ -608,24 +561,20 @@ end throw(AssertionError("Type $T is not concrete type or concrete tuple")) end - @static if VERSION < v"1.7.0" - nT = T + nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) + Tuple{(ntuple(length(T.parameters)) do i + Base.@_inline_meta + sT = T.parameters[i] + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg + Any + else + sT + end + end)...} else - nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) - Tuple{(ntuple(length(T.parameters)) do i - Base.@_inline_meta - sT = T.parameters[i] - if sT isa TypeVar - Any - elseif sT isa Core.TypeofVararg - Any - else - sT - end - end)...} - else - T - end + T end if staticInTup(Val(nT), seen) @@ -740,13 +689,8 @@ declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) T_size_t = convert(LLVM.LLVMType, Int) - @static if 
VERSION < v"1.8.0" - T_int8 = LLVM.Int8Type() - T_pint8 = LLVM.PointerType(T_int8) - LLVM.FunctionType(T_prjlvalue, [T_pint8, T_size_t, T_prjlvalue]) - else - LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) - end + + LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) end function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool, name::String="") curent_bb = position(B) @@ -760,21 +704,16 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: T_int8 = LLVM.Int8Type() T_pint8 = LLVM.PointerType(T_int8) - @static if VERSION < v"1.7.0" - ptls = reinsert_gcmarker!(fn, B) - ptls = bitcast!(B, ptls, T_pint8) - else - pgcstack = reinsert_gcmarker!(fn, B) - ct = inbounds_gep!(B, - T_pjlvalue, - bitcast!(B, pgcstack, T_ppjlvalue), - [LLVM.ConstantInt(current_task_offset())]) - ptls_field = inbounds_gep!(B, - T_pjlvalue, - ct, [LLVM.ConstantInt(current_ptls_offset())]) - T_ppint8 = LLVM.PointerType(T_pint8) - ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) - end + pgcstack = reinsert_gcmarker!(fn, B) + ct = inbounds_gep!(B, + T_pjlvalue, + bitcast!(B, pgcstack, T_ppjlvalue), + [LLVM.ConstantInt(current_task_offset())]) + ptls_field = inbounds_gep!(B, + T_pjlvalue, + ct, [LLVM.ConstantInt(current_ptls_offset())]) + T_ppint8 = LLVM.PointerType(T_pint8) + ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) if needs_workaround T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -791,11 +730,7 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: alloc_obj, alty = declare_allocobj!(mod) - @static if VERSION < v"1.8.0" - return call!(B, alty, alloc_obj, [ptls, Size, tag], name) - else - return call!(B, alty, alloc_obj, [ct, Size, tag], name) - end + return call!(B, alty, alloc_obj, [ct, Size, tag], name) end function emit_allocobj!(B, T::DataType, name::String="") curent_bb = position(B) @@ -832,19 +767,11 @@ declare_writebarrier!(mod) = 
get_function!(mod, "julia.write_barrier") do T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg=true) end -@static if VERSION < v"1.8.0" -declare_apply_generic!(mod) = get_function!(mod, "jl_apply_generic") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()]) -end -else declare_apply_generic!(mod) = get_function!(mod, "ijl_apply_generic") do T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()]) end -end declare_juliacall!(mod) = get_function!(mod, "julia.call") do T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -877,17 +804,10 @@ function emit_getfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LL args = [val, fld] - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - inv = bitcast!(B, inv, LLVM.PointerType(FT)) - res = call!(B, FT, inv, args) - LLVM.callconv!(res, 37) - else - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) - end + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -930,11 +850,7 @@ function emit_box_int32!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value T_int32 = LLVM.Int32Type() FT = LLVM.FunctionType(T_prjlvalue, [T_int32]) - @static if VERSION < v"1.8-" - box_int32, _ = get_function!(mod, "jl_box_int32", FT) - else - box_int32, _ = get_function!(mod, "ijl_box_int32", FT) - end + box_int32, _ 
= get_function!(mod, "ijl_box_int32", FT) call!(B, FT, box_int32, [val]) end @@ -948,11 +864,7 @@ function emit_box_int64!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value T_int64 = LLVM.Int64Type() FT = LLVM.FunctionType(T_prjlvalue, [T_int64]) - @static if VERSION < v"1.8-" - box_int64, _ = get_function!(mod, "jl_box_int64", FT) - else - box_int64, _ = get_function!(mod, "ijl_box_int64", FT) - end + box_int64, _ = get_function!(mod, "ijl_box_int64", FT) call!(B, FT, box_int64, [val]) end @@ -967,25 +879,13 @@ function emit_apply_generic!(B::LLVM.IRBuilder, args)::LLVM.Value T_int32 = LLVM.Int32Type() gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) - @static if VERSION < v"1.8-" - inv, _ = get_function!(mod, "jl_apply_generic", gen_FT) - else - inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) - end + inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - inv = bitcast!(B, inv, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - res = call!(B, FT, inv, args) - LLVM.callconv!(res, 37) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -1001,25 +901,13 @@ function emit_invoke!(B::LLVM.IRBuilder, args)::LLVM.Value # {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* @ijl_invoke gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) - @static if VERSION < v"1.8-" - inv = get_function!(mod, "jl_invoke", gen_FT) - else - inv = get_function!(mod, "ijl_invoke", gen_FT) - end + inv = get_function!(mod, "ijl_invoke", gen_FT) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - inv = bitcast!(B, inv, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - res = call!(B, FT, inv, args) - LLVM.callconv!(res, 38) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call2", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - res = call!(B, FT, julia_call, [inv, args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call2", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + res = call!(B, FT, julia_call, [inv, args...]) return res end @@ -1104,7 +992,6 @@ function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) print(io, msg, '\n') end -@static if VERSION >= v"1.8.0" const JuliaEnzymeNameMap = Dict{String, Any}( "enz_val_true" => Val(true), "enz_val_false" => Val(false), @@ -1119,9 +1006,6 @@ const JuliaEnzymeNameMap = Dict{String, Any}( "enz_no_shadow_exc" => EnzymeNoShadowError, "enz_no_derivative_exc" => EnzymeNoDerivativeError, ) -else -const JuliaEnzymeNameMap = Dict{String, Any}() -end const JuliaGlobalNameMap = Dict{String, Any}( "jl_type_type" => Type, @@ -1204,17 +1088,11 @@ const JuliaGlobalNameMap = Dict{String, Any}( "jl_nothing" => nothing, "jl_anytuple_type" => Tuple, + "jl_vararg_type" => Core.TypeofVararg, + "jl_opaque_closure_type" => Core.OpaqueClosure, + "jl_array_uint64_type" => Array{UInt64, 1}, + "jl_binding_type" => Core.Binding ) -@static if VERSION >= v"1.7.0" - JuliaGlobalNameMap["jl_vararg_type"] = Core.TypeofVararg - JuliaGlobalNameMap["jl_opaque_closure_type"] = Core.OpaqueClosure -end -@static if VERSION >= v"1.8.0" - JuliaGlobalNameMap["jl_array_uint64_type"] = Array{UInt64, 1} -end -@static if VERSION >= v"1.10.0" - JuliaGlobalNameMap["jl_binding_type"] = Core.Binding -end include("absint.jl") @@ -1248,19 +1126,11 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value f_apply_type, _ = get_function!(mod, "jl_f_apply_type", generic_FT) Ty = unsafe_to_llvm(B, Ty) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - f_apply_type = bitcast!(B, f_apply_type, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, 
i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - tag = call!(B, FT, f_apply_type, LLVM.Value[LLVM.PointerNull(T_prjlvalue), Ty, args...]) - LLVM.callconv!(tag, 37) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...]) return tag end @@ -1293,19 +1163,11 @@ function emit_tuple!(B, args)::LLVM.Value generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) f_apply_type, _ = get_function!(mod, "jl_f_tuple", generic_FT) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - f_apply_type = bitcast!(B, f_apply_type, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - tag = call!(B, FT, f_apply_type, LLVM.Value[LLVM.PointerNull(T_prjlvalue), args...]) - LLVM.callconv!(tag, 37) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, 
{}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...]) return tag end @@ -1361,24 +1223,15 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - @static if VERSION < v"1.8.0-" - worlds, FT = get_function!(mod, "jl_gf_invoke_lookup_worlds", - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, sizeT, psizeT, psizeT])) - else worlds, FT = get_function!(mod, "jl_gf_invoke_lookup_worlds", LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT])) - end EB = LLVM.IRBuilder() position!(EB, first(LLVM.instructions(LLVM.entry(fn)))) minworld = alloca!(EB, sizeT) maxworld = alloca!(EB, sizeT) store!(B, LLVM.ConstantInt(sizeT, 0), minworld) store!(B, LLVM.ConstantInt(sizeT, -1), maxworld) - @static if VERSION < v"1.8.0-" - methodmatch = call!(B, FT, worlds, LLVM.Value[tag, LLVM.ConstantInt(sizeT, world), minworld, maxworld]) - else methodmatch = call!(B, FT, worlds, LLVM.Value[tag, unsafe_to_llvm(B, nothing), LLVM.ConstantInt(sizeT, world), minworld, maxworld]) - end # emit_jl!(B, methodmatch) # emit_jl!(B, emit_jltypeof!(B, methodmatch)) 
offset = 1 @@ -2849,38 +2702,10 @@ function from_tape_type(::Type{B}) where {B<:Tuple} end # See get_current_task_from_pgcstack (used from 1.7+) -if VERSION >= v"1.9.1" - current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) -elseif VERSION >= v"1.9.0" - if Sys.WORD_SIZE == 64 - current_task_offset() = -13 - else - current_task_offset() = -18 - end -else - if Sys.WORD_SIZE == 64 - current_task_offset() = -12 #1.8/1.7 - else - current_task_offset() = -17 #1.8/1.7 - end -end +current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) # See get_current_ptls_from_task (used from 1.7+) -if VERSION >= v"1.9.1" - current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) -elseif VERSION >= v"1.9.0" - if Sys.WORD_SIZE == 64 - current_ptls_offset() = 15 - else - current_ptls_offset() = 20 - end -else - if Sys.WORD_SIZE == 64 - current_ptls_offset() = 14 # 1.8/1.7 - else - current_ptls_offset() = 19 - end -end +current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) function store_nonjl_types!(B, startval, p) T_jlvalue = LLVM.StructType(LLVMType[]) @@ -3309,7 +3134,7 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) # Check if Julia version has https://github.com/JuliaLang/julia/pull/46914 # and also https://github.com/JuliaLang/julia/pull/47076 # and also https://github.com/JuliaLang/julia/pull/48620 - @static if VERSION >= v"1.10.0-DEV.569" + @static if VERSION >= v"1.10.5" needs_dynamic_size_workaround = false else needs_dynamic_size_workaround = !isa(Size, LLVM.ConstantInt) || convert(Int, Size) != 1 @@ -3555,9 +3380,7 @@ function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) end end -@static if VERSION < v"1.8" GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = enzyme_ci_cache(job) -end 
GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = Interpreter.EnzymeInterpreter(enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) @@ -5601,14 +5424,12 @@ end using Random # returns arg, return function no_type_setting(@nospecialize(specTypes); world=nothing) - @static if VERSION >= v"1.7.0-" - # Even though the julia type here is ptr{int8}, the actual data can be something else - if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) - return (true, false) - end - if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_nosimd) - return (true, false) - end + # Even though the julia type here is ptr{int8}, the actual data can be something else + if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) + return (true, false) + end + if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_nosimd) + return (true, false) end return (false, false) end @@ -7214,11 +7035,6 @@ function _link(job, (mod, adjoint_name, primal_name, TapeType)) # Now invoke the JIT jitted_mod = JIT.add!(mod) - #if VERSION >= v"1.9.0-DEV.115" - # LLVM.dispose(ctx) - #else - # # we cannot dispose of the global unique context - #end adjoint_addr = JIT.lookup(jitted_mod, adjoint_name) adjoint_ptr = pointer(adjoint_addr) @@ -7382,38 +7198,26 @@ end @inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten} ts_ctx = JuliaContext() - ctx = @static if VERSION >= v"1.9.0-DEV.115" - context(ts_ctx) - else - ts_ctx - end + ctx = context(ts_ctx) activate(ctx) try return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), 
ABI, Val(ErrIfFuncWritten)) finally deactivate(ctx) - @static if VERSION >= v"1.9.0-DEV.115" - dispose(ts_ctx) - end + dispose(ts_ctx) end end @inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} mi = fspec(eltype(FA), TT, World) ts_ctx = JuliaContext() - ctx = @static if VERSION >= v"1.9.0-DEV.115" - context(ts_ctx) - else - ts_ctx - end + ctx = context(ts_ctx) activate(ctx) res = try thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) finally deactivate(ctx) - @static if VERSION >= v"1.9.0-DEV.115" - dispose(ts_ctx) - end + dispose(ts_ctx) end return quote Base.@_inline_meta diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index e1652c5895..08b42d587b 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -52,8 +52,7 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, # parameters for inference and optimization InferenceParams(unoptimize_throw_blocks=false), - VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() : - OptimizationParams(unoptimize_throw_blocks=false), + OptimizationParams(), mode ) end @@ -82,9 +81,7 @@ Core.Compiler.may_compress(interp::EnzymeInterpreter) = true # but as far as I understand Enzyme wants "always inlining, except special cased functions", # so I guess we really don't want to discard sources? 
Core.Compiler.may_discard_trees(interp::EnzymeInterpreter) = false -if VERSION >= v"1.7.0-DEV.577" Core.Compiler.verbose_stmt_info(interp::EnzymeInterpreter) = false -end if isdefined(Base.Experimental, Symbol("@overlay")) Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = @@ -123,21 +120,7 @@ function is_primitive_func(@nospecialize(TT)) end function isKWCallSignature(@nospecialize(TT)) - if VERSION >= v"1.9.0-DEV.1598" - return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} - else - if hasproperty(TT, :parameters) && length(TT.parameters) >= 3 - kwftype = TT.parameters[1] - ft = TT.parameters[3] - if ccall(:jl_argument_method_table, Any, (Any,), ft) === nothing - return false - end - if Core.kwftype(ft) == kwftype - return true - end - end - return false - end + return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} end function simplify_kw(specTypes) @@ -149,8 +132,6 @@ function simplify_kw(specTypes) end # https://github.com/JuliaLang/julia/pull/46965 -@static if VERSION ≥ v"1.9.0-DEV.1535" - import Core.Compiler: CallInfo function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, @nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) @@ -190,81 +171,4 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, src::Any, info::CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) end -# https://github.com/JuliaLang/julia/pull/41328 -elseif isdefined(Core.Compiler, :is_stmt_inline) - -function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, - @nospecialize(src), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) - - method_table = Core.Compiler.method_table(interp) - specTypes = simplify_kw(mi.specTypes) - - if is_primitive_func(specTypes) - return nothing - end - - if is_alwaysinline_func(specTypes) - @assert src !== nothing - return src - end - - if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, 
method_table) - return nothing - end - if interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) - return nothing - end - else - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - return nothing - end - end - - return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, - src::Any, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) -end - -elseif isdefined(Core.Compiler, :inlining_policy) - -import Core.Compiler: InliningTodo, InliningState -struct EnzymeInliningPolicy - interp::EnzymeInterpreter -end -(::EnzymeInliningPolicy)(@nospecialize(src)) = Core.Compiler.default_inlining_policy(src) -Core.Compiler.inlining_policy(interp::EnzymeInterpreter) = EnzymeInliningPolicy(interp) - -function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, T, <:EnzymeInliningPolicy}) where {S<:Union{Nothing, Core.Compiler.EdgeTracker}, T} - mi = todo.mi - specTypes = simplify_kw(mi.specTypes) - - if is_primitive_func(specTypes) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - - if is_alwaysinline_func(specTypes) - @assert false "Need to mark resolve_todo function as alwaysinline, but don't know how" - end - - interp = state.policy.interp - method_table = Core.Compiler.method_table(interp) - if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - if interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - else - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - end - - return Base.@invoke Core.Compiler.resolve_todo( - todo::InliningTodo, 
state::InliningState) -end - -end # @static if isdefined(Core.Compiler, :is_stmt_inline) - end diff --git a/src/compiler/orcv1.jl b/src/compiler/orcv1.jl deleted file mode 100644 index bcac867e73..0000000000 --- a/src/compiler/orcv1.jl +++ /dev/null @@ -1,181 +0,0 @@ -module JIT - -using LLVM -using Libdl -import LLVM: TargetMachine - -import GPUCompiler: CompilerJob, JuliaContext -import ..Compiler -import ..Compiler: API, cpu_name, cpu_features - -export get_trampoline - -# We have one global JIT and TM -const jit = Ref{OrcJIT}() -const tm = Ref{TargetMachine}() - -get_tm() = tm[] - -function __init__() - opt_level = Base.JLOptions().opt_level - if opt_level < 2 - optlevel = LLVM.API.LLVMCodeGenLevelNone - elseif opt_level == 2 - optlevel = LLVM.API.LLVMCodeGenLevelDefault - else - optlevel = LLVM.API.LLVMCodeGenLevelAggressive - end - - tm[] = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel) - LLVM.asm_verbosity!(tm[], true) - - jit[] = OrcJIT(tm[]) # takes ownership of tm - - if haskey(ENV, "ENABLE_GDBLISTENER") - LLVM.register!(jit[], LLVM.GDBRegistrationListener()) - end - atexit() do - dispose(jit[]) - end -end - -mutable struct CallbackContext - job::CompilerJob - stub::Symbol - l_job::ReentrantLock - addr::Ptr{Cvoid} - CallbackContext(job, stub, l_job) = new(job, stub, l_job, C_NULL) -end - -const l_outstanding = Base.ReentrantLock() -const outstanding = Base.IdSet{CallbackContext}() - -# Setup the lazy callback for creating a module -function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid}) - JuliaContext() do ctx - orc = OrcJIT(orc_ref) - cc = Base.unsafe_pointer_to_objref(callback_ctx)::CallbackContext - - # 1. Lock job - lock(cc.l_job) - - # 2. lookup if we are the first - lock(l_outstanding) - if in(cc, outstanding) - delete!(outstanding, cc) - else - unlock(l_outstanding) - unlock(cc.l_job) - - # 3. We are the second callback to run, but we raced the other one - # thus we return the addr from them. 
- @assert cc.addr != C_NULL - return UInt64(reinterpret(UInt, cc.addr)) - end - unlock(l_outstanding) - - try - thunk = Compiler._link(cc.job, Compiler._thunk(cc.job)) - mode = cc.job.config.params.mode - use_primal = mode == API.DEM_ReverseModePrimal - cc.addr = use_primal ? thunk.primal : thunk.adjoint - - # 4. Update the stub pointer to point to the recently compiled module - set_stub!(orc, string(cc.stub), cc.addr) - finally - unlock(cc.l_job) - end - - # 5. Return the address of the implementation, since we are going to call it now - @assert cc.addr != C_NULL - return UInt64(reinterpret(UInt, cc.addr)) - end -end - -function get_trampoline(job) - l_job = Base.ReentrantLock() - - cc = CallbackContext(job, gensym(:func), l_job) - lock(l_outstanding) - push!(outstanding, cc) - unlock(l_outstanding) - - c_callback = @cfunction(callback, UInt64, (LLVM.API.LLVMOrcJITStackRef, Ptr{Cvoid})) - - orc = jit[] - addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc)) - create_stub!(orc, string(cc.stub), addr_adjoint) - - return address(orc, string(cc.stub)) -end - - -function resolver(name, ctx) - name = unsafe_string(name) - ptr = try - ## Step 0: Should have already resolved it iff it was in the - ## same module - ## Step 1: See if it's something known to the execution enging - # TODO: Do we need to do this? - # address(jit[], name) - - ## Step 2: Search the program symbols - # - # SearchForAddressOfSymbol expects an unmangled 'C' symbol name. - # Iff we are on Darwin, strip the leading '_' off. 
- @static if Sys.isapple() - if name[1] == '_' - name = name[2:end] - end - end - - found = false - val = nothing - - hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) - for (k, v) in Compiler.JuliaGlobalNameMap - if "ejl_"*k == name - val = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) - found = true - break - end - end - - if !found - for (k, v) in Compiler.JuliaEnzymeNameMap - if "ejl_"*k == name - val = Compiler.unsafe_to_ptr(v) - found = true - break - end - end - end - - if found - val - else - LLVM.API.LLVMSearchForAddressOfSymbol(name) - end - ## Step 4: Lookup in libatomic - # TODO: Do we need to do this? - catch ex - @error "Enzyme: Lookup failed" name exception=(ex, Base.catch_backtrace()) - C_NULL - end - if ptr === C_NULL - @show name - error("Enzyme: Symbol lookup failed. Aborting!") - end - - return UInt64(reinterpret(UInt, ptr)) -end - -function add!(mod) - return compile!(jit[], mod, @cfunction(resolver, UInt64, (Cstring, Ptr{Cvoid}))) -end - -function lookup(jitted_mod, name) - return LLVM.addressin(jit[], jitted_mod, name) -end - -end diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index d36b1ca1c1..40d13eea80 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -1,3 +1,4 @@ + module JIT using LLVM @@ -9,7 +10,7 @@ import ..Compiler import ..Compiler: API, cpu_name, cpu_features @inline function use_ojit() - return (VERSION >= v"1.10.0-DEV.1395") && !Sys.iswindows() + return !Sys.iswindows() end export get_trampoline @@ -132,11 +133,7 @@ function __init__() jit[] = CompilerInstance(lljit, nothing, nothing) end - hnd = @static if VERSION >= v"1.10" - unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) - else - Libdl.dlopen("libjulia") - end + hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) for (k, v) in Compiler.JuliaGlobalNameMap ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, 
"ejl_"*k), ptr)) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 5b7100f887..944b0b2498 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -35,7 +35,6 @@ function reflect(@nospecialize(func), @nospecialize(A), @nospecialize(types); return llvmf, mod end -# For VERSION >= v"1.9.0-DEV.516" struct jl_llvmf_dump TSM::LLVM.API.LLVMOrcThreadSafeModuleRef F::LLVM.API.LLVMValueRef @@ -46,30 +45,12 @@ function enzyme_code_llvm(io::IO, @nospecialize(func), @nospecialize(A), @nospec raw::Bool=false, debuginfo::Symbol=:default, dump_module::Bool=false, mode=API.DEM_ReverseModeCombined) JuliaContext() do ctx entry_fn, ir = reflect(func, A, types; optimize, run_enzyme, second_stage, mode) - @static if VERSION >= v"1.9.0-DEV.516" - ts_mod = ThreadSafeModule(ir) - if VERSION >= v"1.9.0-DEV.672" - GC.@preserve ts_mod entry_fn begin - value = Ref(jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) - str = ccall(:jl_dump_function_ir, Ref{String}, - (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), - value, !raw, dump_module, debuginfo) - end - else - GC.@preserve ts_mod entry_fn begin - # N.B. 
jl_dump_function_ir will `Libc.free` the passed-in pointer - value_ptr = reinterpret(Ptr{jl_llvmf_dump}, - Libc.malloc(sizeof(jl_llvmf_dump))) - unsafe_store!(value_ptr, jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) - str = ccall(:jl_dump_function_ir, Ref{String}, - (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), - value_ptr, !raw, dump_module, debuginfo) - end - end - else + ts_mod = ThreadSafeModule(ir) + GC.@preserve ts_mod entry_fn begin + value = Ref(jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) str = ccall(:jl_dump_function_ir, Ref{String}, - (LLVM.API.LLVMValueRef, Bool, Bool, Ptr{UInt8}), - entry_fn, !raw, dump_module, debuginfo) + (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), + value, !raw, dump_module, debuginfo) end print(io, str) end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index e4825e5226..6615b6bd40 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -261,72 +261,6 @@ T_ppjlvalue() = LLVM.PointerType(LLVM.PointerType(LLVM.StructType(LLVMType[]))) return v end -if VERSION < v"1.7.0-DEV.1205" - -declare_ptls!(mod) = get_function!(mod, "julia.ptls_states", LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue()))) - -function emit_ptls!(B) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - func, fty = declare_ptls!(mod) - return call!(B, fty, func) -end - -function get_ptls(func) - entry_bb = first(blocks(func)) - ptls_func = declare_ptls!(LLVM.parent(func)) - - for I in instructions(entry_bb) - if I isa LLVM.CallInst && called_operand(I) == ptls_func - return I - end - end - return nothing -end - -function reinsert_gcmarker!(func, PB=nothing) - ptls = get_ptls(func) - if isnothing(ptls) - B = IRBuilder() - entry_bb = first(blocks(func)) - if !isempty(instructions(entry_bb)) - position!(B, first(instructions(entry_bb))) - else - position!(B, entry_bb) - end - emit_ptls!(B) - else - entry_bb = first(blocks(func)) - fst = first(instructions(entry_bb)) - if fst != ptls - API.moveBefore(ptls, fst, PB === nothing ? 
C_NULL : PB.ref) - end - ptls - end -end - -function unique_gcmarker!(func) - entry_bb = first(blocks(func)) - ptls_func = declare_ptls!(LLVM.parent(func)) - - found = LLVM.CallInst[] - for I in instructions(entry_bb) - if I isa LLVM.CallInst && called_operand(I) == ptls_func - push!(found, I) - end - end - if length(found) > 1 - for i in 2:length(found) - LLVM.replace_uses!(found[i], found[1]) - Base.unsafe_delete!(entry_bb, found[i]) - end - end - return nothing -end - -else - function declare_pgcstack!(mod) get_function!(mod, "julia.get_pgcstack", LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue()))) end @@ -398,7 +332,6 @@ function unique_gcmarker!(func) end return nothing end -end @inline AnonymousStruct(::Type{U}) where U<:Tuple = NamedTuple{ntuple(i->Symbol(i), Val(length(U.parameters))), U} diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index b95d343bfb..51aeacf675 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -9,56 +9,19 @@ module FFI using LinearAlgebra using ObjectFile using Libdl - @static if VERSION >= v"1.7" - function __init__() - @static if VERSION > v"1.8" - global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) - else - global blas_handle = Libdl.dlopen(BLAS.libblas) - end - end - function get_blas_symbols() - symbols = BLAS.get_config().exported_symbols - if BLAS.USE_BLAS64 - return map(n->n*"64_", symbols) - end - return symbols - end - - function lookup_blas_symbol(name) - Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) - end - else - function __init__() - global blas_handle = Libdl.dlopen(BLAS.libblas) - end - function get_blas_symbols() - symbols = Set{String}() - path = Libdl.dlpath(BLAS.libblas) - ignoreSymbols = Set(String["", "edata", "_edata", "end", "_end", "_bss_start", "__bss_start", ".text", ".data"]) - for meta in readmeta(open(path, "r")) - for s in Symbols(meta) - name = symbol_name(s) - if !Sys.iswindows() && BLAS.vendor() == :openblas64 - endswith(name, "64_") || 
continue - else - endswith(name, "_") || continue - end - if !in(name, ignoreSymbols) - push!(symbols, name) - end - end - end - symbols = collect(symbols) - if Sys.iswindows() && BLAS.vendor() == :openblas64 - return map(n->n*"64_", symbols) - end - return symbols + function __init__() + global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) + end + function get_blas_symbols() + symbols = BLAS.get_config().exported_symbols + if BLAS.USE_BLAS64 + return map(n->n*"64_", symbols) end + return symbols + end - function lookup_blas_symbol(name) - Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) - end + function lookup_blas_symbol(name) + Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) end end @@ -361,11 +324,9 @@ end return has_method(sig, mt.world, nothing) end -@static if VERSION >= v"1.7" @inline function has_method(sig, world::UInt, mt::Core.Compiler.OverlayMethodTable) return has_method(sig, mt.mt, mt.world) || has_method(sig, nothing, mt.world) end -end @inline function is_inactive(tys, world::UInt, mt) specTypes = Interpreter.simplify_kw(Tuple{tys...}) @@ -739,11 +700,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint,), ptr, 0) if length(frames) >= 1 - @static if VERSION >= v"1.4.0-DEV.123" - fn, file, line, linfo, fromC, inlined = last(frames) - else - fn, file, line, linfo, fromC, inlined, ip = last(frames) - end + fn, file, line, linfo, fromC, inlined = last(frames) # Remember pointer in our global map fn = FFI.memoize!(ptr, string(fn)) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index bcb3b1c413..8b772e8e24 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -107,9 +107,7 @@ function EnzymeRules.inactive(::typeof(Base.startswith), ::AbstractString, args. return nothing end -if VERSION >= v"1.9" - Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= nothing -end +Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) = nothing @inline EnzymeRules.inactive_type(v::Type{Nothing}) = true @inline EnzymeRules.inactive_type(v::Type{Union{}}) = true @@ -379,15 +377,6 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} nothing end -@static if VERSION < v"1.8.0" - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT}, - LinearAlgebra.QRCompactWY{eltype(AT), AT} - } -else UT = Union{ LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, LinearAlgebra.LowerTriangular{eltype(AT), AT}, @@ -395,7 +384,6 @@ else LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, LinearAlgebra.QRPivoted{eltype(AT), AT, onedimensionalize(BT), Vector{Int}} } -end cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{ eltype(RT), @@ -532,7 +520,6 @@ _zero_unused_elements!(X, ::LowerTriangular) = tril!(X) _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) -@static if VERSION >= v"1.7-" # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} primal = if EnzymeRules.needs_primal(config) @@ -581,7 +568,6 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty end return (nothing, nothing) end -end function EnzymeRules.forward( ::Const{typeof(sort!)}, diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index bdcdd79b25..49622ada90 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1302,23 +1302,9 @@ function generic_setup(orig, func, ReturnType, gutils, 
start, B::LLVM.IRBuilder, end pushfirst!(vals, etup) - @static if VERSION < v"1.7.0-" || true - else - mi = emit_methodinstance!(B, func, vals) - end - pushfirst!(vals, unsafe_to_llvm(B, func)) - @static if VERSION < v"1.7.0-" || true - else - pushfirst!(vals, mi) - end - - @static if VERSION < v"1.7.0-" || true cal = emit_apply_generic!(B, vals) - else - cal = emit_invoke!(B, vals) - end debug_from_orig!(gutils, cal, orig) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 1db4cd8d0b..54208fe21c 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -167,13 +167,8 @@ end # TODO actually do modifiedBetween -@static if VERSION < v"1.8-" - e_tt = Tuple{} - modifiedBetween = (mode != API.DEM_ForwardMode, ) -else e_tt = Tuple{Const{Int}} modifiedBetween = (mode != API.DEM_ForwardMode, false) -end world = enzyme_extract_world(LLVM.parent(position(B))) @@ -374,10 +369,7 @@ end push!(vals, tape) end - @static if VERSION < v"1.8-" - else - push!(vals, new_from_original(gutils, operands(orig)[end-1])) - end + push!(vals, new_from_original(gutils, operands(orig)[end-1])) return refed, LLVM.name(subfunc), dfuncT, vals, thunkTy, TapeType, copies end @@ -392,11 +384,7 @@ end _, sname, dfuncT, vals, thunkTy, _, _ = threadsfor_common(orig, gutils, B, API.DEM_ForwardMode) -@static if VERSION < v"1.8-" - tt = Tuple{thunkTy, dfuncT} -else tt = Tuple{thunkTy, dfuncT, Bool} -end mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world) @@ -431,17 +419,7 @@ end byRef, sname, dfuncT, vals, thunkTy, _, copies = threadsfor_common(orig, gutils, B, API.DEM_ReverseModePrimal) -@static if VERSION < v"1.8-" - if byRef - emit_error(B, orig, "Enzyme: active variable in Threads.@threads closure "*(string(eltype(eltype(dfuncT))))*" not supported") - end -end - -@static if VERSION < v"1.8-" - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, 
Val{byRef}} -else tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, Bool} -end mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world) @@ -489,11 +467,7 @@ end Vector{TapeType} end -@static if VERSION < v"1.8-" - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, STT } -else tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, STT, Bool} -end mode = get_mode(gutils) entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world) push!(function_attributes(entry), EnumAttribute("alwaysinline")) diff --git a/src/typetree.jl b/src/typetree.jl index 73b296b95b..40b01edcce 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -199,11 +199,7 @@ else end end -if VERSION >= v"1.7.0-DEV.204" - import Base: ismutabletype -else - ismutabletype(T) = isa(T, DataType) && T.mutable -end +import Base: ismutabletype function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) @@ -214,10 +210,8 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) return TypeTree() end - @static if VERSION >= v"1.7.0" - if is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) - return TypeTree() - end + if is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) + return TypeTree() end if T <: AbstractFloat diff --git a/src/utils.jl b/src/utils.jl index cc2af40c74..ac312e8295 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -124,11 +124,7 @@ function hasfieldcount(@nospecialize(dt)) return true end -if VERSION <= v"1.6" - allocatedinline(@nospecialize(T)) = T.isinlinealloc -else - import Base: allocatedinline -end +import Base: allocatedinline #Excerpt from https://github.com/JuliaGPU/GPUCompiler.jl/blob/v0.19.4/src/jlgen.jl # !!! 
warning "codegen_world_age below is fundamentally unsound." @@ -154,8 +150,6 @@ using Base: _methods_by_ftype # directly, instead use `cached_compilation` which handles invalidation for you. -if VERSION >= v"1.10.0-DEV.873" - # on 1.10 (JuliaLang/julia#48611) the generated function knows which world it was invoked in function _generated_ex(world, source, ex) @@ -178,16 +172,9 @@ function codegen_world_age_generator(world::UInt, source, self, ft::Type, tt::Ty min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - mthds = if VERSION >= v"1.7.0-DEV.1297" - Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, - world, #=ambig=# false, - min_world, max_world, has_ambig) - # XXX: use the correct method table to support overlaying kernels - else - Base._methods_by_ftype(sig, #=lim=# -1, + mthds = Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, world, #=ambig=# false, min_world, max_world, has_ambig) - end mthds === nothing && return _generated_ex(world, source, method_error) length(mthds) == 1 || return _generated_ex(world, source, method_error) @@ -234,95 +221,6 @@ end $(Expr(:meta, :generated, codegen_world_age_generator)) end -else - -# on older versions of Julia we fall back to looking up the current world. this may be wrong -# when the generator is invoked in a different world (TODO: when does this happen?) 
- -function codegen_world_age_generator(self, ft::Type, tt::Type) - @nospecialize - @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) - ft = ft.parameters[1] - tt = tt.parameters[1] - - # validation - ft <: Core.Builtin && error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") - - # look up the method - method_error = :(throw(MethodError(ft, tt))) - sig = Tuple{ft, tt.parameters...} - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - mthds = if VERSION >= v"1.7.0-DEV.1297" - Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, - #=world=# typemax(UInt), #=ambig=# false, - min_world, max_world, has_ambig) - # XXX: use the correct method table to support overlaying kernels - else - Base._methods_by_ftype(sig, #=lim=# -1, - #=world=# typemax(UInt), #=ambig=# false, - min_world, max_world, has_ambig) - end - # XXX: using world=-1 is wrong, but the current world isn't exposed to this generator - mthds === nothing && return method_error - length(mthds) == 1 || return method_error - - # look up the method and code instance - mtypes, msp, m = mthds[1] - mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp) - ci = retrieve_code_info(mi)::CodeInfo - - # prepare a new code info - new_ci = copy(ci) - empty!(new_ci.code) - empty!(new_ci.codelocs) - resize!(new_ci.linetable, 1) # see note below - empty!(new_ci.ssaflags) - new_ci.ssavaluetypes = 0 - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] - new_ci.edges = MethodInstance[mi] - # XXX: setting this edge does not give us proper method invalidation, see - # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. - # invoking `code_llvm` also does the necessary codegen, as does calling the - # underlying C methods -- which GPUCompiler does, so everything Just Works. 
- - # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:3] - - # return the current world age (which is not technically the codegen world age, - # but works well enough for invalidation purposes) - push!(new_ci.code, ReturnNode(Base.get_world_counter())) - push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` - push!(new_ci.codelocs, 1) # see note below - new_ci.ssavaluetypes += 1 - - # NOTE: we keep the first entry of the original linetable, and use it for location info - # on the call to check_cache. we can't not have a codeloc (using 0 causes - # corruption of the back trace), and reusing the target function's info - # has as advantage that we see the name of the kernel in the backtraces. - - return new_ci -end - -@eval function codegen_world_age(ft, tt) - $(Expr(:meta, :generated_only)) - $(Expr(:meta, - :generated, - Expr(:new, - Core.GeneratedFunctionStub, - :codegen_world_age_generator, - Any[:methodinstance, :ft, :tt], - Any[], - @__LINE__, - QuoteNode(Symbol(@__FILE__)), - true))) -end - -end - export codegen_world_age diff --git a/test/DiffTests.jl b/test/DiffTests.jl index 98851f1559..9dd5abdb4d 100644 --- a/test/DiffTests.jl +++ b/test/DiffTests.jl @@ -29,13 +29,8 @@ num2num_3(x) = 10.31^(x + x) - x num2num_4(x) = 1.0 num2num_5(x) = 1. / (1. 
+ exp(-x)) -@static if sizeof(Int) == Int64 || VERSION ≥ v"1.7-" const NUMBER_TO_NUMBER_FUNCS = (num2num_1, num2num_2, num2num_3, num2num_4, num2num_5, identity) -else -const NUMBER_TO_NUMBER_FUNCS = (num2num_1, num2num_2, num2num_3, - num2num_4, identity) -end ####################### # f(x::Number)::Array # @@ -120,25 +115,12 @@ end self_weighted_logit(x) = inv(1.0 + exp(-dot(x, x))) -@static if VERSION ≥ v"1.10-" # vec2num_6 fails due to #708 # rosenbrock_4 fails on nightly for unknown reasons const VECTOR_TO_NUMBER_FUNCS = (vec2num_1, vec2num_2, vec2num_3, vec2num_4, vec2num_5, #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, rosenbrock_3, #=rosenbrock_4,=# ackley, self_weighted_logit, first) -elseif sizeof(Int) == Int64 || VERSION ≥ v"1.7-" -# vec2num_6 fails due to #708 -const VECTOR_TO_NUMBER_FUNCS = (vec2num_1, vec2num_2, vec2num_3, vec2num_4, vec2num_5, - #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, - rosenbrock_3, rosenbrock_4, ackley, self_weighted_logit, - first) -else -const VECTOR_TO_NUMBER_FUNCS = (#=vec2num_1,=# vec2num_2, vec2num_3, vec2num_4, vec2num_5, - #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, - rosenbrock_3, rosenbrock_4, #=ackley,=# self_weighted_logit, - first) -end ######################## # f(x::Matrix)::Number # ######################## diff --git a/test/applyiter.jl b/test/applyiter.jl index 11e9ebf37c..5b55617e55 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -305,7 +305,6 @@ end Enzyme.autodiff(Forward, metaconcat, Const(a)) -@static if VERSION ≥ v"1.7-" dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) @test length(dres) == 5 @test dres[1] ≈ 7.0 @@ -351,7 +350,6 @@ end @test dres[3] == "b" @test dres[4] == "c" @test dres[5] == "d" -end y = [(-92.0, -93.0), (-97.9, -911.2)] dy = [(-913.7, -915.2), (-9100.02, -9304.1)] diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 2b1e9bc621..b9a705941c 100644 --- a/test/internal_rules.jl +++ 
b/test/internal_rules.jl @@ -62,9 +62,7 @@ end end @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 - @static if VERSION < v"1.7-" || VERSION >= v"1.8-" - @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) - end + @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 @@ -285,7 +283,6 @@ end end end -@static if VERSION > v"1.8" @testset "Cholesky" begin function symmetric_definite(n :: Int=10) α = one(Float64) @@ -619,7 +616,6 @@ end end end end -end @testset "rand and randn rules" begin # Distributed as x + unit normal + uniform diff --git a/test/mixed.jl b/test/mixed.jl index dae0623073..4de521414b 100644 --- a/test/mixed.jl +++ b/test/mixed.jl @@ -26,7 +26,6 @@ end @test 6.2 ≈ Enzyme.autodiff(Reverse, outmixedmul2, Const, Duplicated(res, dres), Active(3.1))[1][2] end -@static if VERSION >= v"1.8-" @testset "Batched Byref Mixed Activity" begin res = Ref(4.7) dres = Ref(1.0) @@ -35,7 +34,6 @@ end @test 6.2 ≈ sig[1][2][1] @test 3*6.2 ≈ sig[1][2][2] end -end function tupmixedmul(x::Float64) vec = [x] @@ -59,7 +57,6 @@ end @test 6.2 ≈ Enzyme.autodiff(Reverse, outtupmixedmul, Const, Duplicated(res, dres), Active(3.1))[1][2] end -@static if VERSION >= v"1.8-" @testset "Batched Byref Tuple Mixed Activity" begin res = Ref(4.7) dres = Ref(1.0) @@ -68,4 +65,3 @@ end @test 6.2 ≈ sig[1][2][1] @test 3*6.2 ≈ sig[1][2][2] end -end diff --git a/test/rrules.jl b/test/rrules.jl index be4c4f1424..6c2a965b0e 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -295,7 +295,6 @@ function plaquette_sum(U) end -@static if VERSION >= v"1.9" @testset "No caching byref julia" begin U = Complex{Float64}[3.0 + 4.0im] dU = Complex{Float64}[0.0] @@ -304,8 +303,6 @@ end @test dU[1] ≈ 7 * ( 3.0 + 4.0im ) end -end - struct Closure v::Vector{Float64} 
diff --git a/test/runtests.jl b/test/runtests.jl index 250fa20c97..dc826cd5b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,3 @@ -# HACK: work around Pkg.jl#2500 -if VERSION < v"1.8-" -test_project = Base.active_project() -preferences_file = joinpath(dirname(@__DIR__), "LocalPreferences.toml") -test_preferences_file = joinpath(dirname(test_project), "LocalPreferences.toml") -if isfile(preferences_file) && !isfile(test_preferences_file) - cp(preferences_file, test_preferences_file) -end -end - # # work around https://github.com/JuliaLang/Pkg.jl/issues/1585 # using Pkg # Pkg.develop(PackageSpec(; path=joinpath(dirname(@__DIR__), "lib", "EnzymeTestUtils"))) @@ -80,7 +70,8 @@ function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1) @test isapproxfn((Enzyme.Forward, f), dx_fwd, dx_fd; rtol=rtol, atol=atol, kwargs...) end -Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false) +# Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils])) +# Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils])) include("abi.jl") include("typetree.jl") @@ -91,12 +82,9 @@ include("typetree.jl") include("kwrules.jl") include("kwrrules.jl") include("internal_rules.jl") - @static if VERSION ≥ v"1.9-" - # XXX invalidation does not work on Julia 1.8 - include("ruleinvalidation.jl") - end + include("ruleinvalidation.jl") end -@static if VERSION ≥ v"1.7-" || !Sys.iswindows() +@static if !Sys.iswindows() include("blas.jl") end @@ -394,17 +382,11 @@ make3() = (1.0, 2.0, 3.0) test_scalar(cbrt, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) test_scalar(Base.sinh, 1.0) test_scalar(Base.cosh, 1.0) - if sizeof(Int) == Int64 || VERSION ≥ v"1.7-" test_scalar(Base.sinc, 2.2) - end test_scalar(Base.FastMath.sinh_fast, 1.0) test_scalar(Base.FastMath.cosh_fast, 1.0) - if sizeof(Int) == Int64 || VERSION ≥ v"1.7-" 
test_scalar(Base.FastMath.exp_fast, 1.0) - end - if sizeof(Int) == Int64 || VERSION ≥ v"1.7-" test_scalar(Base.exp10, 1.0) - end test_scalar(Base.exp2, 1.0) test_scalar(Base.expm1, 1.0) test_scalar(x->rem(x, 1), 0.7) @@ -454,11 +436,7 @@ end Const{typeof(dot)}, Active, Duplicated{typeof(thunk_A)} ) @test Tuple{Float64,Float64} === TapeType - Ret = if VERSION < v"1.8-" - Active{Float64} - else - Active - end + Ret = Active fwd, rev = Enzyme.autodiff_deferred_thunk( ReverseSplitWithPrimal, TapeType, @@ -474,31 +452,28 @@ end @test all(dA .== def_dA) @test all(dA .== thunk_dA) - @static if VERSION < v"1.8-" - else - function kernel(len, A) - for i in 1:len - A[i] *= A[i] - end + function kernel(len, A) + for i in 1:len + A[i] *= A[i] end + end - A = Array{Float64}(undef, 64) - dA = Array{Float64}(undef, 64) + A = Array{Float64}(undef, 64) + dA = Array{Float64}(undef, 64) - A .= (1:1:64) - dA .= 1 + A .= (1:1:64) + dA .= 1 - function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT} - TapeType = Enzyme.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) - forward, reverse = Enzyme.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) - forward(Const(f), Const(ctx), args...)[1] - return nothing - end + function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT} + TapeType = Enzyme.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) + forward, reverse = Enzyme.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) 
+ forward(Const(f), Const(ctx), args...)[1] + return nothing + end - ModifiedBetween = Val((false, false, true)) + ModifiedBetween = Val((false, false, true)) - aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA)) - end + aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA)) end @@ -880,34 +855,31 @@ end @test autodiff(Forward, arsum, Duplicated(inp, dinp))[1] ≈ 2.0 - # On Julia 1.6 the gradients are wrong (1.0 too large) and on 1.7 it errors - @static if VERSION ≥ v"1.8-" - function f1(m) - s = 0.0 - for (i, col) in enumerate(eachcol(m)) - s += i * sum(col) - end - return s + function f1(m) + s = 0.0 + for (i, col) in enumerate(eachcol(m)) + s += i * sum(col) end + return s + end - m = Float64[1 2 3; 4 5 6; 7 8 9] - dm = zero(m) - autodiff(Reverse, f1, Active, Duplicated(m, dm)) - @test dm == Float64[1 2 3; 1 2 3; 1 2 3] + m = Float64[1 2 3; 4 5 6; 7 8 9] + dm = zero(m) + autodiff(Reverse, f1, Active, Duplicated(m, dm)) + @test dm == Float64[1 2 3; 1 2 3; 1 2 3] - function f2(m) - s = 0.0 - for (i, col) in enumerate(eachrow(m)) - s += i * sum(col) - end - return s + function f2(m) + s = 0.0 + for (i, col) in enumerate(eachrow(m)) + s += i * sum(col) end - - dm = zero(m) - autodiff(Reverse, f2, Active, Duplicated(m, dm)) - @test dm == Float64[1 1 1; 2 2 2; 3 3 3] + return s end + dm = zero(m) + autodiff(Reverse, f2, Active, Duplicated(m, dm)) + @test dm == Float64[1 1 1; 2 2 2; 3 3 3] + function my_conv_3(x, w) y = zeros(Float64, 2, 3, 4, 5) for hi in axes(y, 3) @@ -2300,7 +2272,6 @@ function bc2_loss_function(x, scale, bias) return sum(abs2, bc2_affine_normalize(identity, x_, xmean, xvar, scale_, bias_, 1e-5)) end -@static if VERSION ≥ v"1.8-" @testset "Broadcast noalias" begin x = ones(30) @@ -2315,7 +2286,6 @@ end Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) end -end function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T a1 = 1 / 
@inbounds poly[1] @@ -2870,14 +2840,12 @@ end @test y[1] == [0.0, 1.0, 0.0] @test y[2] == [0.0, 0.0, 1.0] -@static if VERSION ≥ v"1.9-" x = @SArray [5.0 0.0 6.0] dx = Enzyme.gradient(Forward, prod, x) @test dx[1] ≈ 0 @test dx[2] ≈ 30 @test dx[3] ≈ 0 end -end function sparse_eval(x::Vector{Float64}) @@ -2887,7 +2855,6 @@ function sparse_eval(x::Vector{Float64}) return A[1] end -@static if VERSION ≥ v"1.7-" @testset "Type Unstable SparseArrays" begin x = [3.1, 2.7, 8.2] dx = [0.0, 0.0, 0.0] @@ -2897,7 +2864,6 @@ end @test x ≈ [3.1, 2.7, 8.2] @test dx ≈ [-1.0, 43.74, 0] end -end @testset "Simple Jacobian" begin @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 @@ -3357,11 +3323,7 @@ end @test res[1] ≈ 0.2 # broken as the return of an apply generic is {primal, primal} # but since the return is abstractfloat doing the - @static if VERSION ≥ v"1.9-" && !(VERSION ≥ v"1.10-" ) - @test_broken res[2] ≈ 1.0 - else - @test res[2] ≈ 1.0 - end + @test res[2] ≈ 1.0 end @inline function uns_mymean(f, A, ::Type{T}, c) where T @@ -3412,7 +3374,6 @@ end @test dx ≈ Float64[1.0] end -@static if VERSION < v"1.8-" || VERSION >= v"1.9-" @inline extract_bc(bc, ::Val{:north}) = (bc.north) @inline extract_bc(bc, ::Val{:top}) = (bc.top) @@ -3437,7 +3398,6 @@ end Enzyme.API.looseTypeAnalysis!(false) end -end @testset "Static activity" begin @@ -3539,11 +3499,9 @@ end @test res.x == 5.0 - if VERSION > v"1.10-" - res = autodiff(Reverse, g, Active, Active(Moo(3.0, "a")))[1][1] + res = autodiff(Reverse, g, Active, Active(Moo(3.0, "a")))[1][1] - @test res.x == 5.0 - end + @test res.x == 5.0 end @testset "Type preservation" begin @@ -3800,15 +3758,11 @@ end @test autodiff(Reverse, f8, Active, Active(1.5))[1][1] == 0 @test autodiff(Forward, f8, Duplicated(1.5, 1.0))[1] == 0 - # On Julia 1.6 the gradients are wrong (0.7 not 1.2) and on 1.7 it errors - @static if VERSION ≥ v"1.8-" - f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) - @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 - 
@test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 - end + f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) + @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 + @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 end -@static if VERSION >= v"1.7-" @testset "hvcat_fill" begin ar = Matrix{Float64}(undef, 2, 3) dar = [1.0 2.0 3.0; 4.0 5.0 6.0] @@ -3824,26 +3778,19 @@ end end # TEST EXTENSIONS -@static if VERSION ≥ v"1.9-" - using SpecialFunctions - @testset "SpecialFunctions ext" begin - lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] - test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) - test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) - end - - using ChainRulesCore - @testset "ChainRulesCore ext" begin - include("ext/chainrulescore.jl") - end - include("ext/logexpfunctions.jl") - - @testset "BFloat16s ext" begin - include("ext/bfloat16s.jl") - end +using SpecialFunctions +@testset "SpecialFunctions ext" begin + lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] + test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) + test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) end - - +using ChainRulesCore +@testset "ChainRulesCore ext" begin + include("ext/chainrulescore.jl") end +include("ext/logexpfunctions.jl") +@testset "BFloat16s ext" begin + include("ext/bfloat16s.jl") +end From dffb431ddae5bae755fd646485b83b96c02fd07b Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 12 Sep 2024 09:15:26 -0500 Subject: [PATCH 272/495] Update Project.toml --- lib/EnzymeTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 80dd2ede75..05e5e6b94e 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.7" +version = "0.1.8" [deps] 
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" From b140e0e1f3cc70d8135b3f78f371285a675fd188 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 12 Sep 2024 17:04:56 -0500 Subject: [PATCH 273/495] Fix MixedDuplicated ABI error on primalerror (#1815) --- src/compiler.jl | 10 +++++----- test/mixed.jl | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f82ae6c135..ba53836409 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4928,7 +4928,7 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, end # Modified from GPUCompiler/src/irgen.jl:365 lower_byval -function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function, actualRetType::Type, RetActivity, TT) +function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function, actualRetType::Type, RetActivity, TT, run_enzyme) entry_ft = LLVM.function_type(entry_f) RT = LLVM.return_type(entry_ft) @@ -4985,7 +4985,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[]) elseif arg.cc != GPUCompiler.BITS_REF - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) + if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme push!(boxedArgs, arg.arg_i) push!(raisedArgs, arg.arg_i) push!(wrapper_types, LLVM.PointerType(typ, Derived)) @@ -4996,7 +4996,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end else # bits ref, and not boxed - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) + if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme push!(boxedArgs, arg.arg_i) 
push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) @@ -5931,7 +5931,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; sret = get_return_info(k.ci.rettype)[2] !== nothing if sret cur = llvmfn == primalf - llvmfn, _, boxedArgs, loweredArgs = lower_convention(mi.specTypes, mod, llvmfn, k.ci.rettype, Duplicated, nothing) + llvmfn, _, boxedArgs, loweredArgs = lower_convention(mi.specTypes, mod, llvmfn, k.ci.rettype, Duplicated, nothing, params.run_enzyme) if cur primalf = llvmfn lowerConvention = false @@ -6002,7 +6002,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; primalf, returnRoots = primalf, false if lowerConvention - primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT) + primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT, params.run_enzyme) end if primal_job.config.target isa GPUCompiler.NativeCompilerTarget diff --git a/test/mixed.jl b/test/mixed.jl index 4de521414b..dc4c510b23 100644 --- a/test/mixed.jl +++ b/test/mixed.jl @@ -65,3 +65,20 @@ end @test 6.2 ≈ sig[1][2][1] @test 3*6.2 ≈ sig[1][2][2] end + +struct Foobar + x::Int + y::Int + z::Int + q::Int + r::Float64 +end + +function bad_abi(fb) + v = fb.x + throw(AssertionError("saw bad val $v")) +end + +@testset "Mixed PrimalError" begin + @test_throws AssertionError autodiff(Reverse, bad_abi, MixedDuplicated(Foobar(2, 3, 4, 5, 6.0), Ref(Foobar(2, 3, 4, 5, 6.0)))) +end \ No newline at end of file From e63c1b75f1e5d1158722e02c6d048dfe9fbe30ae Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Sat, 14 Sep 2024 02:43:38 +0900 Subject: [PATCH 274/495] adjustments to the latest inlining interface changes (#1350) * adjustments to the latest inlining interface changes * Update src/compiler/interpreter.jl * 
rebase * Update interpreter.jl * fix `inlining_policy` overload --------- Co-authored-by: William S. Moses --- .gitignore | 1 + src/compiler/interpreter.jl | 113 +++++++++++++++++++++++------------- 2 files changed, 75 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 594a8584c4..e7ee8ed2f5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.jl.cov *.jl.mem /Manifest.toml +/Manifest-v*.toml /test/Manifest.toml /docs/Manifest.toml /docs/build/ diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 08b42d587b..46ca95ab32 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -68,20 +68,17 @@ else end # No need to do any locking since we're not putting our results into the runtime cache -Core.Compiler.lock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing -Core.Compiler.unlock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing +Core.Compiler.lock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing +Core.Compiler.unlock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing -function Core.Compiler.add_remark!(interp::EnzymeInterpreter, sv::InferenceState, msg) -end - -Core.Compiler.may_optimize(interp::EnzymeInterpreter) = true -Core.Compiler.may_compress(interp::EnzymeInterpreter) = true +Core.Compiler.may_optimize(::EnzymeInterpreter) = true +Core.Compiler.may_compress(::EnzymeInterpreter) = true # From @aviatesk: # `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded, # but as far as I understand Enzyme wants "always inlining, except special cased functions", # so I guess we really don't want to discard sources? 
-Core.Compiler.may_discard_trees(interp::EnzymeInterpreter) = false -Core.Compiler.verbose_stmt_info(interp::EnzymeInterpreter) = false +Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false +Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false if isdefined(Base.Experimental, Symbol("@overlay")) Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = @@ -123,7 +120,7 @@ function isKWCallSignature(@nospecialize(TT)) return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} end -function simplify_kw(specTypes) +function simplify_kw(@nospecialize specTypes) if isKWCallSignature(specTypes) return Base.tuple_type_tail(Base.tuple_type_tail(specTypes)) else @@ -131,44 +128,82 @@ function simplify_kw(specTypes) end end -# https://github.com/JuliaLang/julia/pull/46965 import Core.Compiler: CallInfo -function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, - @nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) - +struct NoInlineCallInfo <: CallInfo + info::CallInfo # wrapped call + tt # ::Type + kind::Symbol + NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = new(info, tt, kind) +end +Core.Compiler.nsplit_impl(info::NoInlineCallInfo) = Core.Compiler.nsplit(info.info) +Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +struct AlwaysInlineCallInfo <: CallInfo + info::CallInfo # wrapped call + tt # ::Type + AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt) +end +Core.Compiler.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info) +Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = 
Core.Compiler.getresult(info.info, idx) + +using Core.Compiler: ArgInfo, StmtInfo, AbsIntState +function Core.Compiler.abstract_call_gf_by_type(interp::EnzymeInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int) + ret = @invoke Core.Compiler.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, + arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int) + callinfo = ret.info method_table = Core.Compiler.method_table(interp) - specTypes = simplify_kw(mi.specTypes) - + specTypes = simplify_kw(atype) if is_primitive_func(specTypes) - @safe_debug "Blocking inlining for primitive func" mi.specTypes - return nothing - end - - if is_alwaysinline_func(specTypes) - @safe_debug "Forcing inlining for primitive func" mi.specTypes - @assert src !== nothing - return src + callinfo = NoInlineCallInfo(callinfo, atype, :primitive) + elseif is_alwaysinline_func(specTypes) + callinfo = AlwaysInlineCallInfo(callinfo, atype) + elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + elseif interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :frule) + end + elseif EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end - - if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) - @safe_debug "Blocking inlining due to inactive rule" mi.specTypes - return nothing + @static if VERSION ≥ v"1.11-" + return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) + else + return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo) end +end - if interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) - @safe_debug "Blocking 
inlining due to frule" mi.specTypes - return nothing - end +let # overload `inlining_policy` + @static if VERSION ≥ v"1.11.0-DEV.879" + sigs_ex = :(interp::EnzymeInterpreter, @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt32) + args_ex = :(interp::AbstractInterpreter, src::Any, info::Core.Compiler.CallInfo, stmt_flag::UInt32) else - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - @safe_debug "Blocking inling due to rrule" mi.specTypes + sigs_ex = :(interp::EnzymeInterpreter, + @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) + args_ex = :(interp::AbstractInterpreter, + src::Any, info::Core.Compiler.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) + end + @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) + if info isa NoInlineCallInfo + if info.kind === :primitive + @safe_debug "Blocking inlining for primitive func" info.tt + elseif info.kind === :inactive + @safe_debug "Blocking inlining due to inactive rule" info.tt + elseif info.kind === :frule + @safe_debug "Blocking inlining due to frule" info.tt + else + @assert info.kind === :rrule + @safe_debug "Blocking inlining due to rrule" info.tt + end return nothing + elseif info isa AlwaysInlineCallInfo + @safe_debug "Forcing inlining for primitive func" info.tt + return src end + return @invoke Core.Compiler.inlining_policy($(args_ex.args...)) end - - return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, - src::Any, info::CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) end -end +end # module Interpreter From 3528723712fd6af6f25822ad9f3f4a214be0b4f4 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Fri, 13 Sep 2024 18:10:49 -0400 Subject: [PATCH 275/495] more comprehensive unit tests for gradient and jacobian (#1773) * more comprehensive unit tests for gradient and jacobian * More extensive 
sugar tests * more comprehensive unit tests for gradient and jacobian * More extensive sugar tests * hopefully working now? * Update Project.toml * try fixing tests on 1.6 * Update Project.toml --------- Co-authored-by: Billy Moses Co-authored-by: William Moses --- src/Enzyme.jl | 10 +- test/runtests.jl | 275 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 267 insertions(+), 18 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 450d96ffb0..a5b949c60a 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1341,7 +1341,7 @@ For functions who return other types, this function will retun an array or tuple of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} num = ((n_out_val + chunk - 1) ÷ chunk) if chunk == 0 @@ -1417,7 +1417,7 @@ of shape `size(output)` of values of the input type. end end -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} XT = Core.Typeof(x) MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} @@ -1466,12 +1466,12 @@ end end end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, Holomorphic, ErrIfFuncWritten} res = f(x) jac = if res isa AbstractArray - jacobian(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x, Val(length(jac))) + jacobian(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) elseif res isa AbstractFloat - gradient(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x) + gradient(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) else throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) end diff --git a/test/runtests.jl b/test/runtests.jl index dc826cd5b5..6ffd3dd09c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,15 @@ using InlineStrings using Enzyme_jll @info "Testing against" Enzyme_jll.libEnzyme +# symbol is \simeq +# this is basically a more flexible version of ≈ +(≃)(a, b) = (≈)(a, b) +(≃)(a::Tuple, b::Tuple) = all(xy -> xy[1] ≃ xy[2], zip(a,b)) +function (≃)(a::AbstractArray{<:Tuple}, b::AbstractArray{<:Tuple}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a,b)) +end + function isapproxfn(fn, args...; kwargs...) isapprox(args...; kwargs...) 
end @@ -2865,6 +2874,259 @@ end @test dx ≈ [-1.0, 43.74, 0] end + +# these are used in gradient and jacobian tests +struct InpStruct + i1::Float64 + i2::Float64 + i3::Float64 +end +struct OutStruct + i1::Float64 + i2::Float64 + i3::Float64 +end + +for A ∈ (:InpStruct, :OutStruct) + @eval (≃)(a::$A, b::$A) = (a.i1 ≃ b.i1) && (a.i2 ≃ b.i2) && (a.i3 ≃ b.i3) + @eval function (≃)(a::AbstractArray{<:$A}, b::AbstractArray{<:$A}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a, b)) + end +end + + +#NOTE: this is needed because of problems with hvcat on 1.10 and something inexplicable on 1.6 +# suffice it to say it's not good that this is required, please remove when possible +mkarray(sz, args...) = reshape(vcat(args...), sz) + +@testset "Gradient and Jacobian Outputs" begin + + scalar = 3.0 + + # ∂ scalar / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 + + # ∂ vector / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + + + # ∂ tuple / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + @test_broken 
Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + + mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) + + # ∂ matrix / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + # ∂ struct / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) + + + + vector = [2.7, 3.1] + + # ∂ scalar / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector) ≃ (vector[2],vector[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + + + # ∂ vector / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≃ + ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ + [vector[2] vector[1]; 
-sin(vector[1]) 1.0] + + # ∂ tuple / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ + ((vector[2], -sin(vector[1])), (vector[1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ + ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) + + mkarray2 = x -> mkarray((2,2), x[1]*x[2], exp(x[2]), cos(x[1])+x[2], x[1]) + + # ∂ matrix / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector) ≃ + ([vector[2] -sin(vector[1]); 0.0 1.0], [vector[1] 1.0; exp(vector[2]) 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector) + @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector) ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector) ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + + # ∂ struct / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ + (OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + + + tuplev = (2.7, 3.1) + + # ∂ scalar / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * 
x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + + # ∂ vector / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] + + # ∂ tuple / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + ((vector[2], -sin(vector[1])), (vector[1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + + # ∂ matrix / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev) ≃ + ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev) + @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev) ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev) ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] 
+ + # ∂ struct / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + [OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + + + matrix = [2.7 3.1; 4.7 5.6] + + # ∂ scalar / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≃ + (matrix[1,2], matrix[2,2], matrix[1,1], matrix[2,1]) + @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + + # ∂ vector / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≃ + ([matrix[1,2], 0.0], [0.0, matrix[2,2]], [matrix[1,1], 0.0], [0.0, matrix[2,1]]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) + # again we can't use array construction syntax because of 1.6 + @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, 
matrix[2,1]) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ ((matrix[1,2], 0.0), (0.0, matrix[2,2]), (matrix[1,1], 0.0), (0.0, matrix[2,1])) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + + mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) + + # ∂ matrix / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix) ≃ + ([matrix[1,2] 0.0; exp(matrix[1,1]) 0.0], [0.0 matrix[2,2]; 0.0 1.0], [matrix[1,1] 0.0; 0.0 cos(matrix[1,2])], [0.0 matrix[2,1]; 1.0 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix) + # array construction syntax broken on 1.6 + @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix) ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix) ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ + (OutStruct(matrix[1,2], 0.0, exp(matrix[1,1])), OutStruct(0.0, matrix[2,2], 0.0), OutStruct(matrix[1,1], 0.0, 0.0), OutStruct(0.0, matrix[2,1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) + @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); 
OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) + + + istruct = InpStruct(2.7, 3.1, 4.7) + + # ∂ scalar / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) + @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + + # ∂ vector / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] + + # ∂ tuple / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + + mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) + + # ∂ matrix / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) + @test Enzyme.jacobian(Enzyme.Reverse, 
mkarray4, istruct) ≃ + [InpStruct(istruct.i2, istruct.i1, 0.0) InpStruct(1.0, 0.0, -sin(istruct.i3)); + InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] + + # ∂ struct / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) +end + @testset "Simple Jacobian" begin @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0) ≈ [1.0, 2.0] @@ -2922,12 +3184,6 @@ end @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - struct InpStruct - i1::Float64 - i2::Float64 - i3::Float64 - end - fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] x2 = InpStruct(1.0, 2.0, 3.0) @@ -2946,12 +3202,6 @@ end @test jac[3] == InpStruct(200.0, 400.0, 600.0) @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - struct OutStruct - i1::Float64 - i2::Float64 - i3::Float64 - end - filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x) @@ -2986,7 +3236,6 @@ end @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - end From 24c58efe5bac678aad9ba66fe97a18f9044b3e3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 14 Sep 2024 23:24:42 -0500 Subject: [PATCH 276/495] Ensure typeof doesn't get cached (#1826) --- src/compiler.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 
ba53836409..0d9e147156 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3541,6 +3541,12 @@ function annotate!(mod, mode) push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end + for fname in ("julia.typeof",) + if haskey(fns, fname) + fn = fns[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + end + end for fname in ("jl_excstack_state","ijl_excstack_state", "ijl_field_index", "jl_field_index") if haskey(fns, fname) From afedaac9dac1fc039aa585307398247cbbb54c68 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 14 Sep 2024 23:24:54 -0500 Subject: [PATCH 277/495] Improve deferred error message (#1827) * Improve deferred error message * fix --- src/rules/llvmrules.jl | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index fb93016063..962a4f46af 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1329,6 +1329,46 @@ end end +@register_fwd function deferred_fwd(B, orig, gutils, normalR, shadowR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + newo = new_from_original(gutils, orig) + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + return false +end + +@register_aug function deferred_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + newo = new_from_original(gutils, orig) + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + # Delete the primal code + if normal !== nothing + unsafe_store!(normalR, C_NULL) + else + ni = new_from_original(gutils, orig) + API.EnzymeGradientUtilsErase(gutils, ni) + end + return false +end + +@register_rev function deferred_rev(B, orig, gutils, tape) + return nothing +end + + function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=nothing) for variant in variants if augfwd_handler !== nothing && rev_handler !== nothing @@ -1522,6 +1562,12 @@ end @revfunc(finalizer_rev), @fwdfunc(finalizer_fwd), ) + register_handler!( + ("deferred_codegen",), + @augfunc(deferred_augfwd), + @revfunc(deferred_rev), + @fwdfunc(deferred_fwd), + ) register_handler!( ("jl_array_grow_end","ijl_array_grow_end"), @augfunc(jl_array_grow_end_augfwd), From 9ff45682cfed861054950460a8f3c4f5712e9300 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 01:37:50 -0500 Subject: [PATCH 278/495] Fix diffuse rooting (#1829) * add comment * fix --- src/rules/customrules.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 
9628623987..005557c65a 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1141,6 +1141,22 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return (false, true) end + non_rooting_use = false + fop = called_operand(orig)::LLVM.Function + for (i, v) in enumerate(operands(orig)[1:end-1]) + if v == val + if !any(a->kind(a) == kind(StringAttribute("enzymejl_returnRoots")), collect(parameter_attributes(fop, i))) + non_rooting_use = true + break + end + end + end + + # If the operand is just rooting, we don't need it and should override defaults + if !non_rooting_use + return (false, false) + end + # don't use default and always require the arg return (true, false) end From 0b6effaa00ef511c9e8b6b3474d4b70e07f69ed2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 02:12:00 -0500 Subject: [PATCH 279/495] Fix nightly precompile (#1830) --- src/compiler/interpreter.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 46ca95ab32..61a433af4c 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -40,6 +40,12 @@ end function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) @assert world <= Base.get_world_counter() + parms = @static if VERSION < v"1.12" + InferenceParams(unoptimize_throw_blocks=false), + else + InferenceParams() + end + return EnzymeInterpreter( cache_or_token, mt, @@ -51,7 +57,7 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world, # parameters for inference and optimization - InferenceParams(unoptimize_throw_blocks=false), + parms, OptimizationParams(), mode ) From ffcb7dd977fada76efa88c10d80f3b47d7bdc9a2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 09:15:54 -0500 Subject: [PATCH 280/495] Runtime activity in mode (#1816) * Runtime 
activity in mode * fixup * cr * rules * fix fwd * fr * fr * fr * fr * fr * Update test_forward.jl * Update test_forward.jl * fix * fix * inv * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * improve gradient --- Project.toml | 31 ++--- docs/src/faq.md | 18 ++- examples/custom_rule.jl | 7 +- ext/EnzymeChainRulesCoreExt.jl | 2 +- lib/EnzymeCore/Project.toml | 4 +- lib/EnzymeCore/src/EnzymeCore.jl | 62 ++++++---- lib/EnzymeCore/src/rules.jl | 83 +++++++++---- lib/EnzymeCore/test/runtests.jl | 4 +- lib/EnzymeTestUtils/Project.toml | 8 +- lib/EnzymeTestUtils/src/test_forward.jl | 3 +- lib/EnzymeTestUtils/src/test_reverse.jl | 4 +- lib/EnzymeTestUtils/test/test_forward.jl | 4 +- src/Enzyme.jl | 141 +++++++++++++---------- src/api.jl | 64 ++-------- src/compiler.jl | 38 +++--- src/compiler/reflection.jl | 4 +- src/gradientutils.jl | 1 + src/internal_rules.jl | 99 ++++++++-------- src/rules/customrules.jl | 11 +- src/rules/jitrules.jl | 55 ++++----- src/rules/llvmrules.jl | 6 +- src/rules/parallelrules.jl | 15 +-- src/rules/typeunstablerules.jl | 8 +- test/Project.toml | 2 +- test/ext/.chainrulescore.jl.swp | Bin 0 -> 12288 bytes test/kwrrules.jl | 20 ++-- test/kwrules.jl | 8 +- test/mixedrrule.jl | 8 +- test/rrules.jl | 36 +++--- test/ruleinvalidation.jl | 10 +- test/rules.jl | 24 ++-- test/runtests.jl | 43 +++---- test/sc.jl | 64 ++++++++++ 33 files changed, 495 insertions(+), 392 deletions(-) create mode 100644 test/ext/.chainrulescore.jl.swp create mode 100644 test/sc.jl diff --git a/Project.toml b/Project.toml index 15890547e1..1ea7b5c05b 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.13.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" @@ -16,11 
+17,25 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +EnzymeBFloat16sExt = "BFloat16s" +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeLogExpFunctionsExt = "LogExpFunctions" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [compat] BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.8" +EnzymeCore = "0.8" Enzyme_jll = "0.0.146, 0.0.148" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, =9.0" @@ -31,23 +46,9 @@ SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.10" -[extensions] -EnzymeBFloat16sExt = "BFloat16s" -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeLogExpFunctionsExt = "LogExpFunctions" -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [extras] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[weakdeps] -BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/docs/src/faq.md b/docs/src/faq.md index 5e57a8ada8..88c0cce3b9 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -268,7 +268,7 @@ Enzyme.autodiff(Reverse, f, Active(1.2), 
Const(Vector{Float64}(undef, 1)), Const ((0.0, nothing, nothing, nothing),) ``` -Passing in a dupliacted (e.g. differentiable) variable for `tmp` now leads to the correct answer. +Passing in a duplicated (e.g. differentiable) variable for `tmp` now leads to the correct answer. ```jldoctest storage Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), zeros(1)), Const(1), Const(5)) # Correct (returns 10.367999999999999 == 1.2^4 * 5) @@ -278,9 +278,11 @@ Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), z ((10.367999999999999, nothing, nothing, nothing),) ``` -However, even if we ignore the semantic guarantee provided by marking `tmp` as constant, another issue arises. When computing the original function, intermediate computations (like in `f` above) can use `tmp` for temporary storage. When computing the derivative, Enzyme also needs additional temporary storage space for the corresponding derivative variables as well. If `tmp` is marked as Const, Enzyme does not have any temporary storage space for the derivatives! +## Runtime Activity -Recent versions of Enzyme will attempt to error when they detect these latter types of situations, which we will refer to as `activity unstable`. This term is chosen to mirror the Julia notion of type-unstable code (e.g. where a type is not known at compile time). If an expression is activity unstable, it could either be constant, or active, depending on data not known at compile time. For example, consider the following: +When computing the derivative of mutable variables, Enzyme also needs additional temporary storage space for the corresponding derivative variables. If an argument `tmp` is marked as Const, Enzyme does not have any temporary storage space for the derivatives! + +Enzyme will error when it detects these latter types of situations, which we will refer to as `activity unstable`. This term is chosen to mirror the Julia notion of type-unstable code (e.g.
where a type is not known at compile time). If an expression is activity unstable, it could either be constant, or active, depending on data not known at compile time. For example, consider the following: ```julia function g(cond, active_var, constant_var) @@ -293,7 +295,7 @@ end Enzyme.autodiff(Forward, g, Const(condition), Duplicated(x, dx), Const(y)) ``` -The returned value here could either by constant or duplicated, depending on the runtime-defined value of `cond`. If `cond` is true, Enzyme simply returns the shadow of `active_var` as the derivative. However, if `cond` is false, there is no derivative shadow for `constant_var` and Enzyme will throw a "Mismatched activity" error. For some simple types, e.g. a float Enzyme can circumvent this issue, for example by returning the float 0. Similarly, for some types like the Symbol type, which are never differentiable, such a shadow value will never be used, and Enzyme can return the original "primal" value as its derivative. However, for arbitrary data structures, Enzyme presently has no generic mechanism to resolve this. +The returned value here could either be constant or duplicated, depending on the runtime-defined value of `cond`. If `cond` is true, Enzyme simply returns the shadow of `active_var` as the derivative. However, if `cond` is false, there is no derivative shadow for `constant_var` and Enzyme will throw an `EnzymeRuntimeActivityError` error. For some simple types, e.g. a float, Enzyme can circumvent this issue, for example by returning the float 0. Similarly, for some types like the Symbol type, which are never differentiable, such a shadow value will never be used, and Enzyme can return the original "primal" value as its derivative. However, for arbitrary data structures, Enzyme presently has no generic mechanism to resolve this.
For example consider a third function: ```julia @@ -308,13 +310,17 @@ Enzyme provides a nice utility `Enzyme.make_zero` which takes a data structure a If one created a new zero'd copy of each return from `g`, this would mean that the derivative `dresult` would have one copy made for the first element, and a second copy made for the second element. This could lead to incorrect results, and is unfortunately not a general resolution. However, for non-mutable variables (e.g. like floats) or non-differrentiable types (e.g. like Symbols) this problem can never arise. -Instead, Enzyme has a special mode known as "Runtime Activity" which can handle these types of situations. It can come with a minor performance reduction, and is therefore off by default. It can be enabled with `Enzyme.API.runtimeActivity!(true)` right after importing Enzyme for the first time. +Instead, Enzyme has a special mode known as "Runtime Activity" which can handle these types of situations. It can come with a minor performance reduction, and is therefore off by default. It can be enabled by setting runtime activity to true in a desired differentiation mode. The way Enzyme's runtime activity resolves this issue is to return the original primal variable as the derivative whenever it needs to denote the fact that a variable is a constant. As this issue can only arise with mutable variables, they must be represented in memory via a pointer. All addtional loads and stores will now be modified to first check if the primal pointer is the same as the shadow pointer, and if so, treat it as a constant. Note that this check is not saying that the same arrays contain the same values, but rather the same backing memory represents both the primal and the shadow (e.g. `a === b` or equivalently `pointer(a) == pointer(b)`). 
Enabling runtime activity does therefore, come with a sharp edge, which is that if the computed derivative of a function is mutable, one must also check to see if the primal and shadow represent the same pointer, and if so the true derivative of the function is actually zero. -Generally, the preferred solution to these type of activity unstable codes should be to make your variables all activity-stable (e.g. always containing differentiable memory or always containing non-differentiable memory). However, with care, Enzyme does support "Runtime Activity" as a way to differentiate these programs without having to modify your code. +Generally, the preferred solution to these type of activity unstable codes should be to make your variables all activity-stable (e.g. always containing differentiable memory or always containing non-differentiable memory). However, with care, Enzyme does support "Runtime Activity" as a way to differentiate these programs without having to modify your code. One can enable runtime activity for your code by changing the mode, such as + +```julia +Enzyme.autodiff(set_runtime_activity(Forward), h, Const(condition), Duplicated(x, dx), Const(y)) +``` ## Mixed activity diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index c2098006c2..86ffcf234a 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -57,7 +57,7 @@ using .EnzymeRules # In this section, we write a simple forward rule to start out: -function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated) +function forward(config::FwdConfig, func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated) println("Using custom rule!") ret = func.val(y.val, x.val) y.dval .= 2 .* x.val .* x.dval @@ -65,6 +65,7 @@ function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x: end # In the signature of our rule, we have made use of `Enzyme`'s activity annotations. 
Let's break down each one: +# - the [`FwdConfig`](@ref) configuration passes certain compile-time information about differentiation procedure (the width, and if we're using runtime activity), # - the [`Const`](@ref) annotation on `f` indicates that we accept a function `f` that does not have a derivative component, # which makes sense since `f` is not a closure with data that could be differentiated. # - the [`Duplicated`](@ref) annotation given in the second argument annotates the return value of `f`. This means that @@ -96,7 +97,7 @@ g(y, x) = f(y, x)^2 # function to differentiate # To squeeze out the last drop of performance, the below rule avoids computing the output of the original function and # just computes its derivative. -function forward(func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated) +function forward(config, func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated) println("Using custom rule with DuplicatedNoNeed output.") y.val .= x.val.^2 y.dval .= 2 .* x.val .* x.dval @@ -127,7 +128,7 @@ dy = [0.0, 0.0] Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules -function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, +function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, y::Union{Const, Duplicated}, x::Union{Const, Duplicated}) println("Using our general custom rule!") y.val .= x.val.^2 diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl index 81491f608e..c6a41d7771 100644 --- a/ext/EnzymeChainRulesCoreExt.jl +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -54,7 +54,7 @@ function Enzyme._import_frule(fn, tys...) end quote - function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) 
where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} + function EnzymeRules.forward(config, fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} batchsize = same_or_one(1, $(vals...)) if batchsize == 1 dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index c0c4e0b1e6..2d39f92f45 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,11 +1,11 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.8" +version = "0.8.0" [compat] Adapt = "3, 4" -julia = "1.6" +julia = "1.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 0e67e4e3c0..0175cb4caf 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -216,59 +216,73 @@ const DefaultABI = FFIABI Abstract type for what differentiation mode will be used. """ -abstract type Mode{ABI, ErrIfFuncWritten} end +abstract type Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end """ - struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} + struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} Reverse mode differentiation. - `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. +- `RuntimeActivity`: Should Enzyme enable runtime activity (default off) - `ABI`: What runtime ABI to use - `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz +- `ErrIfFuncWritten`: Should Enzyme err if the function differentiated is a closure and written to. 
""" -struct ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} end -const Reverse = ReverseMode{false,DefaultABI, false, false}() -const ReverseWithPrimal = ReverseMode{true,DefaultABI, false, false}() -const ReverseHolomorphic = ReverseMode{false,DefaultABI, true, false}() -const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true, false}() +struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +const Reverse = ReverseMode{false,false,DefaultABI, false, false}() +const ReverseWithPrimal = ReverseMode{true,false,DefaultABI, false, false}() +const ReverseHolomorphic = ReverseMode{false,false,DefaultABI, true, false}() +const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, false}() -@inline set_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,true}() -@inline clear_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,false}() +@inline set_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,true}() +@inline clear_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,false}() -@inline set_abi(::ReverseMode{ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten}, ::Type{NewABI}) where {ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,NewABI,Holomorphic,ErrIfFuncWritten}() +@inline 
set_abi(::ReverseMode{ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten}, ::Type{NewABI}) where {ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,RuntimeActivity,NewABI,Holomorphic,ErrIfFuncWritten}() + +@inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,true,ABI,Holomorphic,ErrIfFuncWritten}() +@inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}() +@inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}() """ - struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI} + struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} Reverse mode differentiation. - `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. - `ReturnShadow`: Should Enzyme return the shadow return value from the augmented-forward. +- `RuntimeActivity`: Should Enzyme differentiate with runtime activity on (default off). - `Width`: Batch Size (0 if to be automatically derived) - `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). 
""" -struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} end -const ReverseSplitNoPrimal = ReverseModeSplit{false, true, 0, true,DefaultABI, false}() -const ReverseSplitWithPrimal = ReverseModeSplit{true, true, 0, true,DefaultABI, false}() -@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() -@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() +struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}() +const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}() +@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}() +@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}() + +@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, 
ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, true}() +@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, false}() -@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, true}() -@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = 
ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() """ - struct Forward <: Mode + struct ForwardMode{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} Forward mode differentiation """ -struct ForwardMode{ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} +struct ForwardMode{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end -const Forward = ForwardMode{DefaultABI, false}() +const Forward = ForwardMode{DefaultABI, false, false}() +@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,true,RuntimeActivity}() +@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,false,RuntimeActivity}() -@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,true}() -@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,false}() +@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten,RuntimeActivity}() -@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten}() +@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,true}() +@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,rt}() +@inline clear_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} 
= ForwardMode{ABI,ErrIfFuncWritten,false}() function autodiff end function autodiff_deferred end diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 398c790087..27b14619e3 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -1,42 +1,77 @@ module EnzymeRules -import EnzymeCore: Annotation, Const, Duplicated -export Config, ConfigWidth, AugmentedReturn -export needs_primal, needs_shadow, width, overwritten +import EnzymeCore +import EnzymeCore: Annotation, Const, Duplicated, Mode +export RevConfig, RevConfigWidth +export FwdConfig, FwdConfigWidth +export AugmentedReturn +export needs_primal, needs_shadow, width, overwritten, runtime_activity export primal_type, shadow_type, tape_type import Base: unwrapva, isvarargtype, unwrap_unionall, rewrap_unionall """ - forward(func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + forward(fwdconfig, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) -Calculate the forward derivative. The first argument `func` is the callable -for which the rule applies to. Either wrapped in a [`Const`](@ref)), or -a [`Duplicated`](@ref) if it is a closure. -The second argument is the return type annotation, and all other arguments are -the annotated function arguments. +Calculate the forward derivative. The first argument is a [`FwdConfig`](@ref) object +describing parameters of the differentiation. +The second argument `func` is the callable for which the rule applies to. +Either wrapped in a [`Const`](@ref), or a [`Duplicated`](@ref) if it is a closure. +The third argument is the return type annotation, and all other arguments are the annotated function arguments. 
""" function forward end """ - Config{NeedsPrimal, NeedsShadow, Width, Overwritten} - ConfigWidth{Width} = Config{<:Any,<:Any, Width} + FwdConfig{Width, RuntimeActivity} + FwdConfigWidth{Width} = FwdConfig{Width} + +Configuration type to dispatch on in custom forward rules (see [`forward`](@ref)). +* `Width`: an integer that specifies the number of adjoints/shadows simultaneously being propagated. +* `RuntimeActivity`: whether runtime activity is enabled. + +Getters for the type parameters are provided by `width` and `runtime_activity`. +""" +struct FwdConfig{Width, RuntimeActivity} end +const FwdConfigWidth{Width} = FwdConfig{Width} +@inline width(::FwdConfig{Width}) where Width = Width +@inline runtime_activity(::FwdConfig{<:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity + + +""" + RevConfig{NeedsPrimal, NeedsShadow, Width, Overwritten, RuntimeActivity} + RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} Configuration type to dispatch on in custom reverse rules (see [`augmented_primal`](@ref) and [`reverse`](@ref)). * `NeedsPrimal` and `NeedsShadow`: boolean values specifying whether the primal and shadow (resp.) should be returned. * `Width`: an integer that specifies the number of adjoints/shadows simultaneously being propagated. * `Overwritten`: a tuple of booleans of whether each argument (including the function itself) is modified between the forward and reverse pass (true if potentially modified between). +* `RuntimeActivity`: whether runtime activity is enabled. + +Getters for the five type parameters are provided by `needs_primal`, `needs_shadow`, `width`, `overwritten`, and `runtime_activity`. 
+""" +struct RevConfig{NeedsPrimal, NeedsShadow, Width, Overwritten, RuntimeActivity} end +const RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} + +@inline needs_primal(::RevConfig{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +@inline needs_shadow(::RevConfig{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow +@inline width(::RevConfig{<:Any, <:Any, Width}) where Width = Width +@inline overwritten(::RevConfig{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten +@inline runtime_activity(::RevConfig{<:Any, <:Any, <:Any, <:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity + +""" + primal_type(::RevConfig, ::Type{<:Annotation{RT}}) -Getters for the four type parameters are provided by `needs_primal`, `needs_shadow`, `width`, and `overwritten`. +Compute the expected primal return type given a reverse mode config and return activity. """ -struct Config{NeedsPrimal, NeedsShadow, Width, Overwritten} end -const ConfigWidth{Width} = Config{<:Any,<:Any, Width} +@inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing -@inline needs_primal(::Config{NeedsPrimal}) where NeedsPrimal = NeedsPrimal -@inline needs_shadow(::Config{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow -@inline width(::Config{<:Any, <:Any, Width}) where Width = Width -@inline overwritten(::Config{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten +""" + shadow_type(::RevConfig, ::Type{<:Annotation{RT}}) + +Compute the expected shadow return type given a reverse mode config and return activity. +""" +@inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? 
RT : NTuple{width(config), RT}) : Nothing """ AugmentedReturn(primal, shadow, tape) @@ -73,7 +108,7 @@ end @inline tape_type(::Type{AugmentedReturnFlexShadow{PrimalType,ShadowType,TapeType}}) where {PrimalType,ShadowType,TapeType} = TapeType @inline tape_type(::AugmentedReturnFlexShadow{PrimalType,ShadowType,TapeType}) where {PrimalType,ShadowType,TapeType} = TapeType """ - augmented_primal(::Config, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + augmented_primal(::RevConfig, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) Must return an [`AugmentedReturn`](@ref) type. * The primal must be the same type of the original return if `needs_primal(config)`, otherwise nothing. @@ -84,8 +119,8 @@ Must return an [`AugmentedReturn`](@ref) type. function augmented_primal end """ - reverse(::Config, func::Annotation{typeof(f)}, dret::Active, tape, args::Annotation...) - reverse(::Config, func::Annotation{typeof(f)}, ::Type{<:Annotation), tape, args::Annotation...) + reverse(::RevConfig, func::Annotation{typeof(f)}, dret::Active, tape, args::Annotation...) + reverse(::RevConfig, func::Annotation{typeof(f)}, ::Type{<:Annotation), tape, args::Annotation...) Takes gradient of derivative, activity annotation, and tape. If there is an active return dret is passed as Active{T} with the derivative of the active return val. Otherwise dret is passed as Type{Duplicated{T}}, etc. 
@@ -117,7 +152,7 @@ function has_frule_from_sig(@nospecialize(TT); method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, caller::Union{Nothing,Core.MethodInstance}=nothing) ft, tt = _annotate_tt(TT) - TT = Tuple{<:Annotation{ft}, Type{<:Annotation}, tt...} + TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) end @@ -126,7 +161,7 @@ function has_rrule_from_sig(@nospecialize(TT); method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, caller::Union{Nothing,Core.MethodInstance}=nothing) ft, tt = _annotate_tt(TT) - TT = Tuple{<:Config, <:Annotation{ft}, Type{<:Annotation}, tt...} + TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) end @@ -241,4 +276,6 @@ Mark a particular type `Ty` as always being inactive. """ inactive_type(::Type) = false +@inline EnzymeCore.set_runtime_activity(mode::M, config::Config) where {M<:Mode, Config <: Union{FwdConfig, RevConfig}} = EnzymeCore.set_runtime_activity(mode, runtime_activity(config)) + end # EnzymeRules diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index 9b76ebf56b..d85d4dea15 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -4,7 +4,7 @@ using EnzymeCore import EnzymeCore.EnzymeRules: forward, has_frule_from_sig g(x) = x ^ 2 -function forward(::Const{typeof(g)}, ::Type{<:Const}, x::Const) +function forward(config, ::Const{typeof(g)}, ::Type{<:Const}, x::Const) return Const(g(x.val)) end @@ -12,7 +12,7 @@ end f(;kwargs) = 1.0 -function forward(::Const{typeof(f)}, ::Type{<:Const}; kwargs...) +function forward(config, ::Const{typeof(f)}, ::Type{<:Const}; kwargs...) 
return Const(f(; kwargs...)) end diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 05e5e6b94e..72684a9781 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.8" +version = "0.2.0" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -13,12 +13,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ConstructionBase = "1.4.1" -Enzyme = "0.11, 0.12, 0.13" -EnzymeCore = "0.5, 0.6, 0.7" +Enzyme = "0.13" +EnzymeCore = "0.5, 0.6, 0.7, 0.8" FiniteDifferences = "0.12.12" MetaTesting = "0.1" Quaternions = "0.7" -julia = "1.6" +julia = "1.10" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index e57a5c7e34..fcfc987cb9 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -61,6 +61,7 @@ function test_forward( rtol::Real=1e-9, atol::Real=1e-9, testset_name=nothing, + runtime_activity::Bool=false ) call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...) call_with_kwargs(f, xs...) = f(xs...; fkwargs...) @@ -78,7 +79,7 @@ function test_forward( # call finitedifferences, avoid mutating original arguments dy_fdm = _fd_forward(fdm, call_with_copy, ret_activity, y, activities) # call autodiff, allow mutating original arguments - y_and_dy_ad = autodiff(Forward, call_with_kwargs, ret_activity, activities...) + y_and_dy_ad = autodiff(set_runtime_activity(Forward, runtime_activity), call_with_kwargs, ret_activity, activities...) 
if ret_activity <: Union{Duplicated,BatchDuplicated} @test_msg( "For return type $ret_activity the return value and derivative must be returned", diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index f204b00a7b..6c20aebb7a 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -81,6 +81,7 @@ function test_reverse( rtol::Real=1e-9, atol::Real=1e-9, testset_name=nothing, + runtime_activity::Bool=false ) call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...) if testset_name === nothing @@ -108,8 +109,9 @@ function test_reverse( dx_fdm = _fd_reverse(fdm, call_with_captured_kwargs, ȳ, activities, !(ret_activity <: Const)) # call autodiff, allow mutating original arguments c_act = Const(call_with_kwargs) + mode = set_runtime_activity(ReverseSplitWithPrimal, runtime_activity) forward, reverse = autodiff_thunk( - ReverseSplitWithPrimal, typeof(c_act), ret_activity, typeof(Const(fkwargs)), map(typeof, activities)... + mode, typeof(c_act), ret_activity, typeof(Const(fkwargs)), map(typeof, activities)... ) tape, y_ad, shadow_result = forward(c_act, Const(fkwargs), activities...) 
test_approx( diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 7f870af7bf..57385a1dd9 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -178,7 +178,6 @@ end end @testset "mutating function" begin - Enzyme.API.runtimeActivity!(true) sz = (2, 3) @testset for Tret in (Const, Duplicated, BatchDuplicated), Tx in (Const, Duplicated, BatchDuplicated), @@ -196,10 +195,9 @@ end atol = rtol = sqrt(eps(real(T))) @test !fails() do - test_forward(f_mut_fwd!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol) + test_forward(f_mut_fwd!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol, runtime_activity=true) end skip = (VERSION < v"1.8" && T <: Complex) end - Enzyme.API.runtimeActivity!(false) end @testset "incorrect mutated argument detected" begin diff --git a/src/Enzyme.jl b/src/Enzyme.jl index a5b949c60a..fcc12d57a8 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,8 +5,8 @@ import EnzymeCore import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi +import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, 
InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity +export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -229,7 +229,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. """ -@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten} +@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RuntimeActivity,RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RuntimeActivity, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten} tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 @@ -256,7 +256,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI, Val(ErrIfFuncWritten)) + forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) res = forward(f, args...) 
tape = res[1] if ReturnPrimal @@ -286,7 +286,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) args = seed_complex_args(seen, seen2, args...) tt′ = vaTypeof(args...) - thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) results = thunk(f, args..., (rt(0), rt(1), rt(im))) @@ -308,7 +308,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.")) end - thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) if A <: Active args = (args..., Compiler.default_adjoint(rt)) @@ -389,7 +389,7 @@ f(x) = x*x (6.28,) ``` """ -@inline function autodiff(::ForwardMode{RABI, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) 
throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -429,7 +429,7 @@ f(x) = x*x end thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk(f, args...) end @@ -439,7 +439,7 @@ end Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ReverseMode{ReturnPrimal, ABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs, ABI,Holomorphic,ErrIfFuncWritten} +@inline function autodiff_deferred(::ReverseMode{ReturnPrimal, RuntimeActivity, ABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs, ABI,Holomorphic,ErrIfFuncWritten, RuntimeActivity} tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 @@ -463,7 +463,7 @@ code, as well as high-order differentiation. 
ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) if rt <: Active @@ -480,7 +480,7 @@ end Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode{ABI, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten} +@inline function autodiff_deferred(::ForwardMode{ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -527,7 +527,7 @@ code, as well as high-order differentiation. 
ReturnPrimal = RT <: Duplicated || RT <: BatchDuplicated ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) thunk(f, args...) end @@ -608,7 +608,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff_thunk(rs::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT,RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -636,7 +636,7 @@ result, ∂v, ∂A else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end """ @@ -683,7 +683,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = same_or_one(1, A, args...) 
if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -702,10 +702,10 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end -@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -731,7 +731,7 @@ end else Val(codegen_world_age(eltype(FA), primal_tt)) end - nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) if nondef[1] isa Enzyme.Compiler.PrimalErrorThunk return Nothing else @@ -747,9 +747,9 @@ const tape_cache_lock = ReentrantLock() import .Compiler: fspec, remove_innerty, UnknownTapeType @inline function tape_type( - parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, + parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs} -) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, RuntimeActivity} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -779,7 +779,8 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType params = Compiler.EnzymeCompilerParams( Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, - ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI, #=errifwritte=#false + ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI, #=errifwritte=#false, + RuntimeActivity ) job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false)) @@ -849,7 +850,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(1, args...) 
@@ -873,8 +874,8 @@ result, ∂v, ∂A primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten)) + primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) RT = if A2 <: Duplicated && width != 1 if A2 isa UnionAll @@ -1031,15 +1032,23 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) (a = 3.0, b = [2.0], c = "str") ``` """ -@inline function gradient(rm::ReverseMode, f::F, x::X) where {F, X} +@inline function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState dx = Ref(make_zero(x)) - autodiff(rm, f, Active, MixedDuplicated(x, dx)) - return only(dx) + res = autodiff(rm, f, Active, MixedDuplicated(x, dx)) + if ReturnPrimal + (res[2], only(dx)) + else + only(dx) + end else dx = make_zero(x) - autodiff(rm, f, Active, Duplicated(x, dx)) - return dx + res = 
autodiff(rm, f, Active, Duplicated(x, dx)) + if ReturnPrimal + (res[2], dx) + else + dx + end end end @@ -1048,15 +1057,23 @@ end Like [`gradient`](@ref), except it using deferred mode. """ -@inline function gradient_deferred(rm::ReverseMode, f::F, x::X) where {F, X} +@inline function gradient_deferred(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState dx = Ref(make_zero(x)) autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) - return only(dx) + if ReturnPrimal + return (res[2], only(dx)) + else + return only(dx) + end else dx = make_zero(x) autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - return dx + if ReturnPrimal + (res[2], dx) + else + dx + end end end @@ -1082,10 +1099,14 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) 2.0 ``` """ -@inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} +@inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) - autodiff(Reverse, f, Active, Duplicated(x, dx)) - dx + res = autodiff(rm, f, Active, Duplicated(x, dx)) + return if ReturnPrimal + (res[2], dx) + else + dx + end end @@ -1094,10 +1115,14 @@ end Like [`gradient!`](@ref), except it using deferred mode. 
""" -@inline function gradient_deferred!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} +@inline function gradient_deferred!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) - autodiff_deferred(Reverse, f, Active, Duplicated(x, dx)) - dx + autodiff_deferred(rm, f, Active, Duplicated(x, dx)) + return if ReturnPrimal + (res[2], dx) + else + dx + end end """ @@ -1121,11 +1146,11 @@ grad = gradient(Forward, f, [2.0, 3.0]) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f, x; shadow=onehot(x)) +@inline function gradient(fm::ForwardMode, f, x; shadow=onehot(x)) if length(shadow) == 0 return () end - res = values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + res = values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) if x isa AbstractFloat res[1] else @@ -1169,12 +1194,12 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} +@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end tmp = ntuple(length(shadow)) do i - values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) + values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end res = tupleconcat(tmp...) 
if x isa AbstractFloat @@ -1184,9 +1209,9 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) end end -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} +@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} res = ntuple(length(shadow)) do i - autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] + autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end if x isa AbstractFloat res[1] @@ -1223,11 +1248,11 @@ whose shape is `(size(output)..., size(input)...)` For functions who return other types, this function will retun an array or tuple of shape `size(input)` of values of the output type. """ -@inline function jacobian(::ForwardMode, f, x; shadow=onehot(x)) +@inline function jacobian(fm::ForwardMode, f, x; shadow=onehot(x)) cols = if length(shadow) == 0 () else - values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) end if x isa AbstractFloat cols[1] @@ -1252,13 +1277,13 @@ of shape `size(input)` of values of the output type. end end -@inline function jacobian(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} +@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end tmp = ntuple(length(shadow)) do i Base.@_inline_meta - values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) + values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end cols = tupleconcat(tmp...) 
if x isa AbstractFloat @@ -1284,10 +1309,10 @@ end end end -@inline function jacobian(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} +@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} cols = ntuple(length(shadow)) do i Base.@_inline_meta - autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] + autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end if x isa AbstractFloat cols[1] @@ -1341,7 +1366,7 @@ For functions who return other types, this function will retun an array or tuple of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity, RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity} num = ((n_out_val + chunk - 1) ÷ chunk) if chunk == 0 @@ -1360,7 +1385,7 @@ of shape `size(output)` of values of the input type. else Val(codegen_world_age(Core.Typeof(f), tt)) end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) if num * chunk == n_out_val last_size = chunk @@ -1368,7 +1393,7 @@ of shape `size(output)` of values of the input type. 
else last_size = n_out_val - (num-1)*chunk tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end tmp = ntuple(num) do i @@ -1417,7 +1442,7 @@ of shape `size(output)` of values of the input type. end end -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RuntimeActivity,RABI<:ABI, ErrIfFuncWritten} XT = Core.Typeof(x) MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} @@ -1430,7 +1455,7 @@ end else Val(codegen_world_age(Core.Typeof(f), tt)) end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) tmp = ntuple(n_outs) do i Base.@_inline_meta z = make_zero(x) @@ -1466,12 +1491,12 @@ end end end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, Holomorphic, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, Holomorphic} res = f(x) jac = if res isa AbstractArray - jacobian(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) + jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) elseif res isa AbstractFloat - gradient(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) + gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) else throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) end diff --git a/src/api.jl b/src/api.jl index 6de95beb55..9e446dcf57 100644 --- a/src/api.jl +++ b/src/api.jl @@ -156,30 +156,30 @@ end # \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way # \p PostOpt is whether to perform basic optimization of the function after synthesis function 
EnzymeCreatePrimalAndGradient(logic, todiff, retType, constant_args, TA, - returnValue, dretUsed, mode, width, additionalArg, + returnValue, dretUsed, mode, runtimeActivity, width, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, augmented, atomicAdd) freeMemory = true ccall((:EnzymeCreatePrimalAndGradient, libEnzyme), LLVMValueRef, (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, UInt8, CDerivativeMode, Cuint, UInt8, LLVMTypeRef, UInt8, CFnTypeInfo, + EnzymeTypeAnalysisRef, UInt8, UInt8, CDerivativeMode, UInt8, Cuint, UInt8, LLVMTypeRef, UInt8, CFnTypeInfo, Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr, UInt8), logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - dretUsed, mode, width, freeMemory, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, length(uncacheable_args), + dretUsed, mode, runtimeActivity, width, freeMemory, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, length(uncacheable_args), augmented, atomicAdd) end function EnzymeCreateForwardDiff(logic, todiff, retType, constant_args, TA, - returnValue, mode, width, additionalArg, typeInfo, + returnValue, mode, runtimeActivity, width, additionalArg, typeInfo, uncacheable_args) freeMemory = true aug = C_NULL ccall((:EnzymeCreateForwardDiff, libEnzyme), LLVMValueRef, (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, CDerivativeMode, UInt8, Cuint, LLVMTypeRef, CFnTypeInfo, + EnzymeTypeAnalysisRef, UInt8, CDerivativeMode, UInt8, UInt8, Cuint, LLVMTypeRef, CFnTypeInfo, Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr), logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - mode, freeMemory, width, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args), aug) + mode, freeMemory, runtimeActivity, 
width, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args), aug) end # Create an augmented forward pass. @@ -195,14 +195,14 @@ end # \p PostOpt is whether to perform basic optimization of the function after synthesis function EnzymeCreateAugmentedPrimal(logic, todiff, retType, constant_args, TA, returnUsed, shadowReturnUsed, - typeInfo, uncacheable_args, forceAnonymousTape, width, atomicAdd) + typeInfo, uncacheable_args, forceAnonymousTape, runtimeActivity, width, atomicAdd) ccall((:EnzymeCreateAugmentedPrimal, libEnzyme), EnzymeAugmentedReturnPtr, (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, EnzymeTypeAnalysisRef, UInt8, UInt8, - CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, Cuint, UInt8), + CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, UInt8, Cuint, UInt8), logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnUsed, shadowReturnUsed, - typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, width, atomicAdd) + typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, runtimeActivity, width, atomicAdd) end # typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, @@ -252,6 +252,7 @@ EnzymeGradientUtilsErase(gutils, a) = ccall((:EnzymeGradientUtilsErase, libEnzym EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall((:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef, UInt8), gutils, a, orig, erase) EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils) +EnzymeGradientUtilsGetRuntimeActivity(gutils) = ccall((:EnzymeGradientUtilsGetRuntimeActivity, libEnzyme), UInt8, (EnzymeGradientUtilsRef,), gutils) != 0 
EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig) EnzymeGradientUtilsLookup(gutils, val, B) = ccall((:EnzymeGradientUtilsLookup, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) @@ -556,51 +557,6 @@ function strong_zero!(val) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end -""" - runtimeActivity!(val::Bool) - -Enzyme runs an activity analysis which deduces which values, instructions, etc -are necessary to be differentiated and therefore involved in the differentiation -procedure. This runs at compile time. However, there may be implementation flaws -in this analysis that means that Enzyme cannot deduce that an inactive (const) -value is actually const. Alternatively, there may be some data which is conditionally -active, depending on which runtime branch is taken. In these cases Enzyme conservatively -presumes the value is active. - -However, in certain cases, an insufficiently aggressive activity analysis may result -in derivative errors -- for example by mistakenly using the primal (const) argument -and mistaking it for the duplicated shadow. As a result this may result in incorrect -results, or accidental updates to the primal. - -This flag enables runntime activity which tells all load/stores to check at runtime -whether the value they are updating is indeed active (in addition to the compile-time -activity analysis). This will remedy these such errors, but at a performance penalty -of performing such checks. 
- -It is on the Enzyme roadmap to add a PotentiallyDuplicated style activity, in addition -to the current Const and Duplicated styles that will disable the need for this, -which does not require the check when a value is guaranteed active, but still supports -runtime-based activity information. - -This function takes an argument to set the runtime activity value, true means it is on, -and false means off. By default it is off. -""" -function runtimeActivity!(val::Bool) - ptr = cglobal((:EnzymeRuntimeActivityCheck, libEnzyme)) - ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) -end - -""" - runtimeActivity() - -Gets the current value of the runtime activity. See [`runtimeActivity!`](@ref) for -more information. -""" -function runtimeActivity() - ptr = cglobal((:EnzymeRuntimeActivityCheck, libEnzyme)) - return EnzymeGetCLBool(ptr) != 0 -end - """ typeWarning!(val::Bool) diff --git a/src/compiler.jl b/src/compiler.jl index 0d9e147156..fb7ce8d8bf 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -954,11 +954,11 @@ end function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) println(io, "Constant memory is stored (or returned) to a differentiable variable.") println(io, "As a result, Enzyme cannot provably ensure correctness and throws this error.") - println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage).") + println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).") println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!"); println(io, "To work around this issue, either:"); println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or") - println(io, " b) set Enzyme.API.runtimeActivity!(true) immediately after loading 
Enzyme (which maintains correctness, but may slightly reduce performance).") + println(io, " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.") msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end @@ -3324,6 +3324,9 @@ struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams ABI::Type{<:ABI} # Whether to error if the function is written to err_if_func_written::Bool + + # Whether runtime activity is enabled + runtimeActivity::Bool end struct UnknownTapeType end @@ -3843,6 +3846,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr world = job.world interp = GPUCompiler.get_interpreter(job) rt = job.config.params.rt + runtimeActivity = job.config.params.runtimeActivity @assert eltype(rt) != Union{} shadow_init = job.config.params.shadowInit @@ -3960,7 +3964,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr augmented = API.EnzymeCreateAugmentedPrimal( logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed, #=shadowReturnUsed=#shadowReturnUsed, - typeInfo, uncacheable_args, #=forceAnonymousTape=# false, width, #=atomicAdd=# parallel) + typeInfo, uncacheable_args, #=forceAnonymousTape=# false, runtimeActivity, width, #=atomicAdd=# parallel) # 2. 
get new_primalf and tape augmented_primalf = LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) @@ -3988,7 +3992,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( logic, primalf, retType, args_activity, TA, - #=returnValue=#false, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeGradient, width, + #=returnValue=#false, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeGradient, runtimeActivity, width, #=additionalArg=#tape, #=forceAnonymousTape=#false, typeInfo, uncacheable_args, augmented, #=atomicAdd=# parallel)) if wrap @@ -3999,7 +4003,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr returnUsed &= returnPrimal adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeCombined, width, + #=returnValue=#returnUsed, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeCombined, runtimeActivity, width, #=additionalArg=#C_NULL, #=forceAnonymousTape=#false, typeInfo, uncacheable_args, #=augmented=#C_NULL, #=atomicAdd=# parallel)) augmented_primalf = nothing @@ -4011,7 +4015,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr returnUsed &= returnPrimal adjointf = LLVM.Function(API.EnzymeCreateForwardDiff( logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=mode=#API.DEM_ForwardMode, width, + #=returnValue=#returnUsed, #=mode=#API.DEM_ForwardMode, runtimeActivity, width, #=additionalArg=#C_NULL, typeInfo, uncacheable_args)) augmented_primalf = nothing @@ -5495,7 +5499,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; ForwardModeTypes = ("s", "d", "c", "z") ReverseModeTypes = ("s", "d") # Tablegen BLAS does not support forward mode yet - if !(mode == API.DEM_ForwardMode && Enzyme.API.runtimeActivity()) + if !(mode == 
API.DEM_ForwardMode && params.runtimeActivity) for ty in (mode == API.DEM_ForwardMode ? ForwardModeTypes : ReverseModeTypes) for func in (mode == API.DEM_ForwardMode ? ForwardModeDerivatives : ReverseModeDerivatives) for prefix in ("", "cblas_") @@ -7124,9 +7128,9 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} +@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten, RuntimeActivity} target = Compiler.EnzymeTarget() - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten) + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) tmp_job = if World isa Nothing Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else @@ -7165,7 +7169,7 @@ end A2 end - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten) + params 
= Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) job = if World isa Nothing Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else @@ -7202,25 +7206,25 @@ end end end -@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten} +@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten, RuntimeActivity} ts_ctx = JuliaContext() ctx = context(ts_ctx) activate(ctx) try - return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) + return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) finally deactivate(ctx) dispose(ts_ctx) end end -@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} +@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, 
::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten, RuntimeActivity} mi = fspec(eltype(FA), TT, World) ts_ctx = JuliaContext() ctx = context(ts_ctx) activate(ctx) res = try - thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) + thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) finally deactivate(ctx) dispose(ts_ctx) @@ -7234,14 +7238,14 @@ end import GPUCompiler: deferred_codegen_jobs @generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, - ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal},::Val{ShadowInit},::Type{ExpectedTapeType}, ::Val{ErrIfFuncWritten}) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, ErrIfFuncWritten} + ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal},::Val{ShadowInit},::Type{ExpectedTapeType}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, ErrIfFuncWritten, RuntimeActivity} JuliaContext() do ctx Base.@_inline_meta mi = fspec(eltype(FA), TT, World) target = EnzymeTarget() rt2 = if A isa UnionAll - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten) + params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, 
RuntimeActivity) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) interp = GPUCompiler.get_interpreter(tmp_job) @@ -7265,7 +7269,7 @@ import GPUCompiler: deferred_codegen_jobs A end - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten) + params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, RuntimeActivity) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) addr = get_trampoline(job) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 944b0b2498..583a6f2f68 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -1,5 +1,5 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); - run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, ErrIfFuncWritten=false, kwargs...) + run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, ErrIfFuncWritten=false, RuntimeActivity=true, kwargs...) tt = Tuple{map(eltype, types.parameters)...} if world === nothing @@ -15,7 +15,7 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); defaultMod = mode != API.DEM_ReverseModeCombined && mode != API.DEM_ForwardMode modifiedBetween = (defaultMod, (defaultMod for _ in types.parameters)...) end - params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI, ErrIfFuncWritten) + params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) return Compiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world) end diff --git a/src/gradientutils.jl b/src/gradientutils.jl index f7f80fd396..ff83f60c1d 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -13,6 +13,7 @@ end get_width(gutils::GradientUtils) = API.EnzymeGradientUtilsGetWidth(gutils) get_mode(gutils::GradientUtils) = API.EnzymeGradientUtilsGetMode(gutils) +get_runtime_activity(gutils::GradientUtils) = API.EnzymeGradientUtilsGetRuntimeActivity(gutils) function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) w = get_width(gutils) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 8b772e8e24..96f774f69e 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -121,23 +121,13 @@ Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= no @inline EnzymeRules.inactive_type(v::Type{Core.Compiler.WorldRange}) = true @inline EnzymeRules.inactive_type(v::Type{Core.MethodInstance}) = true -@inline width(::Duplicated) = 1 -@inline width(::BatchDuplicated{T, N}) where {T, N} = N -@inline width(::DuplicatedNoNeed) = 1 -@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N - -@inline width(::Type{Duplicated{T}}) where T = 1 -@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N -@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 -@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N - # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return deepcopy(x.dval) end -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} ntuple(Val(N)) do _ deepcopy(x.dval) end @@ -164,19 +154,19 @@ end return seen[shadow] end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, 
x::BatchDuplicated{T, N}) where {T,N} primal = func.val(x.val) return BatchDuplicated(primal, ntuple(Val(N)) do i deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) end) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -244,7 +234,7 @@ end return seen[into] end -function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} if EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) @@ -266,9 +256,9 @@ end unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() + config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) 
TapeType = EnzymeRules.tape_type(fwd_thunk) @@ -291,9 +281,9 @@ end thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() + config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -338,7 +328,7 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) @@ -395,13 +385,13 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} ) return EnzymeRules.AugmentedReturn{ - EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing, - EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? 
eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing, + EnzymeRules.primal_type(config, RT), + EnzymeRules.shadow_type(config, RT), typeof(cache) }(retres, dres, cache) end -function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT y, dys, cache_A, cache_b = cache @@ -469,7 +459,7 @@ const EnzymeTriangulars = Union{ } function EnzymeRules.augmented_primal( - config, + config::EnzymeRules.RevConfig, func::Const{typeof(ldiv!)}, ::Type{RT}, Y::Annotation{YT}, @@ -483,12 +473,17 @@ function EnzymeRules.augmented_primal( primal = EnzymeRules.needs_primal(config) ? Y.val : nothing shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing func.val(Y.val, A.val, B.val) - return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( - primal, shadow, (cache_Y, cache_A, cache_B)) + return EnzymeRules.AugmentedReturn{ + EnzymeRules.primal_type(config, RT), + EnzymeRules.shadow_type(config, RT), + Tuple{typeof(cache_Y), typeof(cache_A), typeof(cache_B)} + }( + primal, shadow, (cache_Y, cache_A, cache_B) + ) end function EnzymeRules.reverse( - config, + config::EnzymeRules.RevConfig, func::Const{typeof(ldiv!)}, ::Type{RT}, cache, @@ -521,7 +516,7 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, 
inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} primal = if EnzymeRules.needs_primal(config) out.val else @@ -536,7 +531,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill return EnzymeRules.AugmentedReturn(primal, shadow, nothing) end -function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} nr, nc = size(out.val,1), size(out.val,2) for b in 1:EnzymeRules.width(config) da = if EnzymeRules.width(config) == 1 @@ -569,7 +564,7 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty return (nothing, nothing) end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}; @@ -587,7 +582,7 @@ function EnzymeRules.forward( end end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, xs::BatchDuplicated{T, N}; @@ -609,7 +604,7 @@ end function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}; @@ -632,7 +627,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, tape, @@ -645,7 +640,7 @@ function EnzymeRules.reverse( return (nothing,) end -function EnzymeRules.forward( +function 
EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(partialsort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}, @@ -670,7 +665,7 @@ function EnzymeRules.forward( end end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(partialsort!)}, RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, xs::BatchDuplicated{T, N}, @@ -702,7 +697,7 @@ function EnzymeRules.forward( end function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(partialsort!)}, RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}, @@ -728,7 +723,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(partialsort!)}, dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, tape, @@ -755,7 +750,7 @@ end # -> # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] -function EnzymeRules.forward(func::Const{typeof(ldiv!)}, +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(ldiv!)}, RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, fact::Annotation{<:Cholesky}, B::Annotation{<:AbstractVecOrMat}; @@ -763,7 +758,7 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, if B isa Const return func.val(fact.val, B.val; kwargs...) 
else - N = width(B) + N = EnzymeRules.width(config) retval = B.val L = fact.val.L @@ -810,7 +805,7 @@ end # Float64 ranges in Julia use bitwise `&` with higher precision # to correct for numerical error, thus we put rules over the # operations as this is not directly differentiable -function EnzymeRules.forward(func::Const{Colon}, +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, BatchDuplicated,BatchDuplicatedNoNeed}}, start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) @@ -820,7 +815,7 @@ function EnzymeRules.forward(func::Const{Colon}, elseif start isa Duplicated || start isa DuplicatedNoNeed start.dval elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed - ntuple(i -> start.dval[i], Val(width(RT))) + ntuple(i -> start.dval[i], Val(EnzymeRules.width(config))) else error("Annotation type $(typeof(start)) not supported for range start. Please open an issue") end @@ -830,7 +825,7 @@ function EnzymeRules.forward(func::Const{Colon}, elseif step isa Duplicated || step isa DuplicatedNoNeed step.dval elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed - ntuple(i -> step.dval[i], Val(width(RT))) + ntuple(i -> step.dval[i], Val(EnzymeRules.width(config))) else error("Annotation type $(typeof(start)) not supported for range step. Please open an issue") end @@ -845,11 +840,11 @@ function EnzymeRules.forward(func::Const{Colon}, BatchDuplicated(ret, ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(width(RT)))) + length=length(ret)), Val(EnzymeRules.width(config)))) elseif RT <: BatchDuplicatedNoNeed ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(width(RT))) + length=length(ret)), Val(EnzymeRules.width(config))) else error("This should not be possible. 
Please report.") end @@ -857,7 +852,7 @@ end -function EnzymeRules.augmented_primal(config, func::Const{Colon}, ::Type{<:Active}, +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{Colon}, ::Type{<:Active}, start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) if EnzymeRules.needs_primal(config) @@ -868,7 +863,7 @@ function EnzymeRules.augmented_primal(config, func::Const{Colon}, ::Type{<:Activ return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config, func::Const{Colon}, dret, tape::Nothing, +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{Colon}, dret, tape::Nothing, start::Annotation{T1}, step::Annotation{T2}, stop::Annotation{T3}) where {T1<:AbstractFloat, T2<:AbstractFloat, T3<:AbstractFloat} dstart = if start isa Const @@ -908,7 +903,7 @@ function EnzymeRules.reverse(config, func::Const{Colon}, dret, tape::Nothing, end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; kwargs... @@ -920,13 +915,13 @@ function EnzymeRules.forward( elseif RT <: Duplicated return RT(Ty.val(; kwargs...), Ty.val(; kwargs...)) elseif RT <: BatchDuplicatedNoNeed - ntuple(Val(width(RT))) do i + ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta Ty.val(; kwargs...) end else @assert RT <: BatchDuplicated - tup = ntuple(Val(width(RT))) do i + tup = ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta Ty.val(; kwargs...) end @@ -935,7 +930,7 @@ function EnzymeRules.forward( end function EnzymeRules.augmented_primal( - config, + config::EnzymeRules.RevConfig, Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, kwargs... 
@@ -961,7 +956,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config, + config::EnzymeRules.RevConfig, Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, tape, diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 005557c65a..75e36370d8 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -352,6 +352,7 @@ end end width = get_width(gutils) + C = EnzymeRules.FwdConfig{Int(width), get_runtime_activity(gutils)} if shadowR != C_NULL unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) @@ -384,10 +385,12 @@ end @assert kwtup !== nothing insert!(tt, 1, kwtup) insert!(tt, 2, Core.typeof(EnzymeRules.forward)) - insert!(tt, 4, Type{RT}) + insert!(tt, 3, C) + insert!(tt, 5, Type{RT}) else @assert kwtup === nothing - insert!(tt, 2, Type{RT}) + insert!(tt, 1, C) + insert!(tt, 3, Type{RT}) end TT = Tuple{tt...} @@ -595,7 +598,7 @@ end fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) - C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} + C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} mode = get_mode(gutils) @@ -673,7 +676,7 @@ end needsShadow end - C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} + C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 49622ada90..f3f05087c0 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -218,7 +218,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = 
thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + forward = thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) @@ -238,13 +238,13 @@ function func_runtime_generic_fwd(N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote - function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} +@generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -333,7 +333,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) opt_mi = Val(world) forward, adjoint = thunk(opt_mi, dupClosure0 ? 
Duplicated{FT} : Const{FT}, annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...) annotation = annotationA @@ -358,13 +358,13 @@ function func_runtime_generic_augfwd(N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, RuntimeActivity, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, RuntimeActivity, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) @@ -462,7 +462,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act 
opt_mi = Val(world) _, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] @@ -480,13 +480,13 @@ function func_runtime_generic_rev(N, Width) body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) quote - function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} +@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, MB, RuntimeActivity, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) @@ -676,7 +676,7 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} +function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) @@ -714,7 +714,7 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType end opt_mi = Val(world) res = thunk(opt_mi, FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false))(fa, args...) + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity)(fa, args...) return if annotation <: Const ReturnType(allFirst(Val(width+1), res)) else @@ -736,7 +736,7 @@ function body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - fwddiff_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + fwddiff_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType end end @@ -745,13 +745,13 @@ function func_runtime_iterate_fwd(N, Width) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} +@generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) @@ -822,7 +822,7 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} +function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) ModifiedBetween = Val(ModifiedBetween0) @@ -869,7 +869,7 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} opt_mi = Val(world) forward, adjoint = thunk(opt_mi, FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) forward(fa, args...) else nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) @@ -933,7 +933,7 @@ function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, a args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - tmpvals = augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType + tmpvals = augfwd_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType ReturnType(($(results...), (tmpvals[$(Width+2)], refs))) end end @@ -943,13 +943,13 @@ function func_runtime_iterate_augfwd(N, Width) body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) quote - function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) @@ -970,7 +970,7 @@ function add_into_vec!(val::T, expr, vec, idx_in_vec) where T end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -@generated function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{ttp}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, ttp, DF, Nargs} +@generated function rev_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{ttp}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {RuntimeActivity, width, dupClosure0, ModifiedBetween0, lengths, FT, ttp, DF, Nargs} nontupexprs = Vector{Expr}(undef, Nargs) for i in 1:Nargs @@ -1092,7 +1092,7 @@ end opt_mi = Val(world) forward, adjoint = thunk(opt_mi, FA, annotation, $ttp, Val(API.DEM_ReverseModePrimal), Val($width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) tup = if tape.shadow_return !== nothing $shadadj @@ -1155,7 +1155,7 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) 
+ rev_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) return nothing end end @@ -1165,13 +1165,13 @@ function func_runtime_iterate_rev(N, Width) body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) quote - function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} +@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true; reverse=true) return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) @@ -1187,7 +1187,7 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true) +function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true, runtime_activity=true) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1234,7 +1234,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, if lookup inverted = lookup_value(gutils, inverted, B) end - if API.runtimeActivity() + if get_runtime_activity(gutils) inv_0 = if width == 1 inverted else @@ -1295,6 +1295,9 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end pushfirst!(vals, unsafe_to_llvm(B, Val(Int(width)))) + if runtime_activity + pushfirst!(vals, unsafe_to_llvm(B, Val(get_runtime_activity(gutils)))) + end etup0 = emit_tuple!(B, ActivityList) etup = emit_apply_type!(B, Base.Val, [etup0]) if isa(etup, LLVM.Instruction) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 962a4f46af..965114447c 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -335,7 +335,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, shadowres), LLVM.ConstantInt(i8, 0, false), length, algn) end - if API.runtimeActivity() + 
if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) shadowres = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, shadowin, new_from_original(gutils, origops[1])), shadowres, prev) API.moveBefore(prev, shadowres, B) @@ -358,7 +358,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn) end - if API.runtimeActivity() + if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) callv = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, ev, new_from_original(gutils, origops[1])), callv, prev) if idx == 1 @@ -1094,7 +1094,7 @@ end else extract_value!(B, shadowin, idx-1) end - if API.runtimeActivity() + if get_runtime_activity(gutils) emit_error(B, orig, "Enzyme: Not yet implemented runtime activity for reverse of jl_array_del_end") end args = LLVM.Value[anti, offset] diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 54208fe21c..2964838947 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -1,9 +1,9 @@ -function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width} +function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}) where {FT1, FT2, World, width, RuntimeActivity} FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) ft = ghos ? 
Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -13,12 +13,12 @@ function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ss return ccall(:jl_new_task, Ref{Task}, (Any, Any, Int), fclosure, post, ssize) end -function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween} +function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween, RuntimeActivity} # TODO make this AD subcall type stable FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -189,7 +189,7 @@ end if mode == API.DEM_ForwardMode if fwdmodenm === nothing etarget = Compiler.EnzymeTarget() - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false) + eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) cmod, fwdmodenm, _, _ = _thunk(ejob, #=postopt=#false) @@ -220,7 +220,7 @@ end if augfwdnm === nothing || adjointnm === nothing etarget = Compiler.EnzymeTarget() # TODO modifiedBetween - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false) + eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, #=postopt=#false) @@ -505,6 +505,7 @@ end invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, new_from_original(gutils, ops[3])), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), unsafe_to_llvm(B, Val(width)), ] @@ -555,7 +556,7 @@ end invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), (sizeof(Int) == sizeof(Int64) ? emit_box_int64! 
: emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(B, Val(width)), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), unsafe_to_llvm(B, Val(width)), unsafe_to_llvm(B, Val(ModifiedBetween)), ] diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 5372a67726..6117e464d8 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -326,7 +326,7 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap width = get_width(gutils) - sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true) + sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true, runtime_activity=false) if width == 1 shadow = sret @@ -370,7 +370,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true) + generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true, runtime_activity=false) end return nothing @@ -399,7 +399,7 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) width = get_width(gutils) - sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false) + sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? 
Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false, runtime_activity=false) if width == 1 shadow = sret @@ -465,7 +465,7 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) else tape end - generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2) + generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2, runtime_activity=false) end return nothing end diff --git a/test/Project.toml b/test/Project.toml index a3f8452712..3ce8fc645c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -24,4 +24,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -EnzymeTestUtils = "0.1.4" +EnzymeTestUtils = "0.1.4, 0.2" diff --git a/test/ext/.chainrulescore.jl.swp b/test/ext/.chainrulescore.jl.swp new file mode 100644 index 0000000000000000000000000000000000000000..94b31875e128106c0c92d4e28639ebc06d4e3451 GIT binary patch literal 12288 zcmeI2&u<$=6vrpMP@ts<4i&dYe85{A{|b?4R99(3f{;oMsNqtoXl&2A3-+$NJ67x- zhy;H@4^W`M0d55$aYLe9;D)$!;m83YAt51gtoY8%I=gn_G;J?LdMkZ)*R$`vdGEV# zY(<%7_088_r?cg%!11&YU*B5xfBJjp*r^k9V!Ib5DlrcZKAhb`Pqvy-^iDsH>g70+ zy>bw06^F7r_I}qHRyu*Mtc5p5Jym|YThS`f6*bhA)-@S~@t~`cRur2@V?VyK0<6Fz zDlpLg!pkQ&7wa{<)C=>^)3eWPK4MUIX9ZXRR)7^?1y})AfE8c`Sb?LYfbP$U9VC1# zP5NAVotwI*ANj%xumY?AE5Hh{0;~WlzzVPetN<&(3a|o4PyyK%;^t#Qe0&_q6AN##eJ+nEJH$IoX27 zKI7{leT^9ov$Jr^kK%0^w~I*>>k?g#70$&W?uMQxtQk%PdjnGpQxH<|l|jE7w4_$; z5?z$KbuAmndsz>J+~bp(Z$ukPwKJ#OoD=z7l!+?4B&Q~H`L0G`q9;;5q`OPtL4#xk z?C&{i^33m`d1rY~Rh|zq^(d#rtYG`6l8qiB#MGotimAavgvsV~5wg#4&Nakjo1328 zJ7Pv>Rs1mB?fOcXu;wVxL-cgInRg}V^|8vfe6xv{^r){Qzqs4i9?^wxBgH4}SFesAS(_D8WP+tN-!$uw+Pf=qo|>wIMe3BZcI3AM-!G16;XL_^MH_bH znb00J8Vxf=7Bn>;rpEX#fPKz*DK$Q*Du2;4)mV_W4ZdpjFTY{-FRz;Y%PUDsc4fR} z#XwS+X9Nw3OJ|}`6^EE`%NCnc{5y`kS=&4X+8#d;le6@bvWj?b) zYv5O%1WUe?;J~Wt9<6sHsTb?_uc(e1^5vw(*kN0V4N;(^uUdg0v?vKWq11y|nL_p{ zrS)aj!ib- Op22j$L`rX6Xz?G6bcs*^ literal 0 HcmV?d00001 diff --git a/test/kwrrules.jl b/test/kwrrules.jl index 
f5b9d2338a..34749e9baa 100644 --- a/test/kwrrules.jl +++ b/test/kwrrules.jl @@ -11,7 +11,7 @@ end import .EnzymeRules: augmented_primal, reverse using .EnzymeRules -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw)}, ::Type{<:Active}, x::Active; kwargs...) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw)}, ::Type{<:Active}, x::Active; kwargs...) @show kwargs @assert length(overwritten(config)) == 2 if needs_primal(config) @@ -21,7 +21,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw)}, ::T end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw)}, dret::Active, tape, x::Active; kwargs...) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw)}, dret::Active, tape, x::Active; kwargs...) @show kwargs # TODO do we want them here? @assert length(overwritten(config)) == 2 if needs_primal(config) @@ -43,7 +43,7 @@ function f_kw2(x; kwargs...) x^2 end -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw2)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw2)}, ::Type{<:Active}, x::Active) if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, nothing) else @@ -51,7 +51,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw2)}, :: end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw2)}, dret::Active, tape, x::Active) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw2)}, dret::Active, tape, x::Active) if needs_primal(config) return (10+2*x.val*dret.val,) else @@ -68,7 +68,7 @@ function f_kw3(x; val=nothing) x^2 end -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw3)}, ::Type{<:Active}, x::Active; dval=nothing) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw3)}, ::Type{<:Active}, x::Active; dval=nothing) if needs_primal(config) return 
AugmentedReturn(func.val(x.val), nothing, nothing) else @@ -76,7 +76,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw3)}, :: end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw3)}, dret::Active, tape, x::Active; dval=nothing) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw3)}, dret::Active, tape, x::Active; dval=nothing) if needs_primal(config) return (10+2*x.val*dret.val,) else @@ -92,7 +92,7 @@ function f_kw4(x; y=2.0) x*y end -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw4)}, ::Type{<:Active}, x::Active; y) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw4)}, ::Type{<:Active}, x::Active; y) @assert length(overwritten(config)) == 2 if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, nothing) @@ -101,7 +101,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw4)}, :: end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw4)}, dret::Active, tape, x::Active; y) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw4)}, dret::Active, tape, x::Active; y) @assert length(overwritten(config)) == 2 return (1000*y+2*x.val*dret.val,) end @@ -126,7 +126,7 @@ function wrapclos(cl, x) cl(x; width=9) end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure2}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{Closure2}, ::Type{<:Active}, args::Vararg{Active,N}; width=7) where {N} vec = copy(func.val.v) pval = func.val(args[1].val) @@ -138,7 +138,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closur return AugmentedReturn(primal, nothing, vec) end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure2}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{Closure2}, dret::Active, tape, args::Vararg{Active,N}; width=7) where {N} dargs = ntuple(Val(N)) do i 7 * 
args[1].val * dret.val + tape[1] * 1000 + width * 100000 diff --git a/test/kwrules.jl b/test/kwrules.jl index 91d3dc859d..9761c23510 100644 --- a/test/kwrules.jl +++ b/test/kwrules.jl @@ -10,7 +10,7 @@ function f_kw(x; kwargs...) x^2 end -function forward(::Const{typeof(f_kw)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; kwargs...) +function forward(config, ::Const{typeof(f_kw)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; kwargs...) return 10+2*x.val*x.dval end @@ -25,7 +25,7 @@ function f_kw2(x; kwargs...) x^2 end -function forward(::Const{typeof(f_kw2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function forward(config, ::Const{typeof(f_kw2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return 10+2*x.val*x.dval end @@ -37,7 +37,7 @@ function f_kw3(x; val=nothing) x^2 end -function forward(::Const{typeof(f_kw3)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; dval=nothing) +function forward(config, ::Const{typeof(f_kw3)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; dval=nothing) return 10+2*x.val*x.dval end @@ -49,7 +49,7 @@ function f_kw4(x; y=2.0) x*y end -function forward(::Const{typeof(f_kw4)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; y) +function forward(config, ::Const{typeof(f_kw4)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; y) return 1000*y+2*x.val*x.dval end diff --git a/test/mixedrrule.jl b/test/mixedrrule.jl index 32407f3c12..db1d3ab251 100644 --- a/test/mixedrrule.jl +++ b/test/mixedrrule.jl @@ -17,7 +17,7 @@ function mixouter(x, y) return res end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(mixfnc)}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(mixfnc)}, ::Type{<:Active}, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) pval = func.val(tup.val) vec = copy(tup.val[2]) @@ -29,7 +29,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof return AugmentedReturn(primal, nothing, vec) end -function EnzymeRules.reverse(config::ConfigWidth{1}, 
func::Const{typeof(mixfnc)}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(mixfnc)}, dret::Active, tape, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) prev = tup.dval[] tup.dval[] = (7 * tape[1] * dret.val, prev[2]) @@ -57,7 +57,7 @@ function recmixouter(x, y, z) return res end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(recmixfnc)}, ::Type{<:Active}, tup) pval = func.val(tup.val) vec = copy(tup.val[2]) @@ -76,7 +76,7 @@ end return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(recmixfnc)}, dret::Active, tape, tup) prev = tup.dval[] dRT = typeof(prev) diff --git a/test/rrules.jl b/test/rrules.jl index 6c2a965b0e..cd41b49716 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -15,7 +15,7 @@ end import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig using .EnzymeRules -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active) if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, nothing) else @@ -23,7 +23,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active) if needs_primal(config) return (10+2*x.val*dret.val,) else @@ -31,13 +31,13 @@ function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, end end -function augmented_primal(::Config{false, false, 1}, 
func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) +function augmented_primal(::RevConfig{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) v = x.val[1] x.val[1] *= v return AugmentedReturn(nothing, nothing, v) end -function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) +function reverse(::RevConfig{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) x.dval[1] = 100 + x.dval[1] * tape return (nothing,) end @@ -107,7 +107,7 @@ end end q(x) = x^2 -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(q)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(q)}, ::Type{<:Active}, x::Active) tape = (Ref(2.0), Ref(3.4)) if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, tape) @@ -116,7 +116,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(q)}, ::Type end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(q)}, dret::Active, tape, x::Active) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(q)}, dret::Active, tape, x::Active) @test tape[1][] == 2.0 @test tape[2][] == 3.4 if needs_primal(config) @@ -133,7 +133,7 @@ end foo(x::Complex) = 2x function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(foo)}, ::Type{<:Active}, x @@ -154,7 +154,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(foo)}, dret, tape, @@ -177,7 +177,7 @@ function _dot(X::StridedArray{T}, Y::StridedArray{T}) where {T<:Union{Real,Compl end function augmented_primal( - config::ConfigWidth{1}, + config::RevConfigWidth{1}, func::Const{typeof(_dot)}, ::Type{<:Union{Const,Active}}, X::Duplicated{<:StridedArray{T}}, @@ -191,7 +191,7 @@ function augmented_primal( end 
function reverse( - ::ConfigWidth{1}, + ::RevConfigWidth{1}, ::Const{typeof(_dot)}, dret::Union{Active,Type{<:Const}}, tape, @@ -235,7 +235,7 @@ function cprimal(x0, y0) return @inbounds x[1] end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, y::Duplicated, x::Duplicated) cmyfunc!(y.val, x.val) tape = (copy(x.val), 3) @@ -243,7 +243,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof end const seen = Set() -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, tape, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, tape, y::Duplicated, x::Duplicated) xval = tape[1] p = pointer(xval) @@ -265,7 +265,7 @@ function remultr(arg) arg * arg end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(remultr)}, ::Type{<:Active}, args::Vararg{Active,N}) where {N} primal = if EnzymeRules.needs_primal(config) func.val(args[1].val) @@ -275,7 +275,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof return AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(remultr)}, dret::Active, tape, args::Vararg{Active,N}) where {N} dargs = ntuple(Val(N)) do i @@ -315,7 +315,7 @@ function (cl::Closure)(x) end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{Closure}, ::Type{<:Active}, args::Vararg{Active,N}) where {N} vec = copy(func.val.v) pval = 
func.val(args[1].val) @@ -327,7 +327,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closur return AugmentedReturn(primal, nothing, vec) end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{Closure}, dret::Active, tape, args::Vararg{Active,N}) where {N} dargs = ntuple(Val(N)) do i 7 * args[1].val * dret.val + tape[1] * 1000 @@ -377,7 +377,7 @@ end unstabletape(x) = x^2 -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(unstabletape)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(unstabletape)}, ::Type{<:Active}, x::Active) tape = if x.val < 3 400 else @@ -390,7 +390,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(unstabletap end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(unstabletape)}, dret, tape, x::Active{T}) where T +function reverse(config::RevConfigWidth{1}, ::Const{typeof(unstabletape)}, dret, tape, x::Active{T}) where T return (T(tape)::T,) end diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 87c52861cc..62579e2415 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -11,25 +11,25 @@ call_issue696(args...) = issue696(args...) 
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 # should invalidate cache for the previous result -forward(::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = +forward(config, ::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = 10+2*x.val*x.dval -forward(func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = +forward(config, func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = Duplicated(func.val(x.val), 10+2*x.val*x.dval) @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 12.0 @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 12.0 # should invalidate cache for the previous result again -forward(::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = +forward(config, ::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = 20+2*x.val*x.dval -forward(func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = +forward(config, func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = Duplicated(func.val(x.val), 20+2*x.val*x.dval) @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 22.0 @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 22.0 # check that `Base.delete_method` works as expected -for m in methods(forward, Tuple{Const{typeof(issue696)},Vararg{Any}}) +for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) Base.delete_method(m) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 diff --git a/test/rules.jl b/test/rules.jl index b6644d8c55..0ef2e0fe8e 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -4,7 +4,7 @@ using Enzyme using Enzyme: EnzymeRules using Test -import .EnzymeRules: forward, Annotation, has_frule_from_sig +import .EnzymeRules: forward, Annotation, has_frule_from_sig, FwdConfig f(x) = x^2 @@ -13,23 +13,23 @@ function f_ip(x) return nothing end -function forward(::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, 
x::Duplicated) +function forward(config, ::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return 10+2*x.val*x.dval end -function forward(::Const{typeof(f)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function forward(config, ::Const{typeof(f)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} return NTuple{N, T}(1000+2*x.val*dv for dv in x.dval) end -function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, x::Duplicated) +function forward(config, func::Const{typeof(f)}, ::Type{<:Duplicated}, x::Duplicated) return Duplicated(func.val(x.val), 100+2*x.val*x.dval) end -function forward(func::Const{typeof(f)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function forward(config, func::Const{typeof(f)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} return BatchDuplicated(func.val(x.val), NTuple{N, T}(10000+2*x.val*dv for dv in x.dval)) end -function forward(::Const{Core.typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) +function forward(config, ::Const{Core.typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) ld = x.val[1] x.val[1] *= ld x.dval[1] *= 2 * ld + 10 @@ -38,7 +38,7 @@ end function has_frule(f, @nospecialize(RT), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter()) TT = Base.unwrap_unionall(TT) - TT = Tuple{<:Annotation{Core.typeof(f)}, Type{<:RT}, TT.parameters...} + TT = Tuple{<:FwdConfig, <:Annotation{Core.typeof(f)}, Type{<:RT}, TT.parameters...} EnzymeRules.isapplicable(forward, TT; world) end @@ -82,7 +82,7 @@ end end g(x) = x ^ 2 -function forward(func::Const{typeof(g)}, ::Type{<:Const}, x::Const) +function forward(config, func::Const{typeof(g)}, ::Type{<:Const}, x::Const) return Const(g(x.val)) end @@ -107,11 +107,11 @@ function h2(x) y * y end -function forward(func::Const{typeof(alloc_sq)}, ::Type{<:Duplicated}, x::Duplicated) +function forward(config, func::Const{typeof(alloc_sq)}, ::Type{<:Duplicated}, x::Duplicated) return 
Duplicated(Ref(x.val*x.val), Ref(10*2*x.val*x.dval)) end -function forward(func::Const{typeof(alloc_sq)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function forward(config, func::Const{typeof(alloc_sq)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return Ref(1000*2*x.val*x.dval) end @@ -123,7 +123,7 @@ function h3(x) alloc_sq2(x)[] end -function forward(func::Const{typeof(alloc_sq2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function forward(config, func::Const{typeof(alloc_sq2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return Duplicated(Ref(0.0), Ref(1000*2*x.val*x.dval)) end @@ -136,7 +136,7 @@ end foo(x) = 2x; -function EnzymeRules.forward( +function EnzymeRules.forward(config, func::Const{typeof(foo)}, RT::Type{<:Union{Duplicated,BatchDuplicated}}, x::Union{Duplicated,BatchDuplicated}, diff --git a/test/runtests.jl b/test/runtests.jl index 6ffd3dd09c..bdda7604bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -141,10 +141,10 @@ end @test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState @test Enzyme.Compiler.active_reg_inner(Tuple{A,A} where A, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState world = codegen_world_age(typeof(f0), Tuple{Float64}) - thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) - thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) - thunk_c = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, 
Val(false)) - thunk_d = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) + thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_c = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_d = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) @test thunk_a.adjoint !== thunk_b.adjoint @test thunk_c.adjoint === thunk_a.adjoint @test thunk_c.adjoint === thunk_d.adjoint @@ -153,7 +153,7 @@ end @test thunk_a(Const(f0), Active(2.0), 2.0) == ((2.0,),) @test thunk_b(Const(f0), Const(2.0)) === ((nothing,),) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) @test forward(Const(f0), Active(2.0)) == (nothing,nothing,nothing) @test pullback(Const(f0), Active(2.0), 1.0, nothing) == ((1.0,),) @@ -164,7 +164,7 
@@ end d = Duplicated([3.0, 5.0], [0.0, 0.0]) world = codegen_world_age(typeof(mul2), Tuple{Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) res = forward(Const(mul2), d) @test typeof(res[1]) == Tuple{Float64, Float64} pullback(Const(mul2), d, 1.0, res[1]) @@ -173,7 +173,7 @@ end d = Duplicated([3.0, 5.0], [0.0, 0.0]) world = codegen_world_age(typeof(vrec), Tuple{Int, Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) res = forward(Const(vrec), Const(Int(1)), d) pullback(Const(vrec), Const(1), d, 1.0, res[1]) @test d.dval[1] ≈ 5.0 @@ -1225,8 +1225,7 @@ end # Technically this test doesn't need runtimeactivity since the closure combo of active itr1 and const data # doesn't use any of the const data values, but now that we error for activity confusion, we need to # mark runtimeActivity to let this pass - Enzyme.API.runtimeActivity!(true) - Enzyme.autodiff(Enzyme.Reverse, Const(smallrf), Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) + Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), Const(smallrf), 
Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) @test dweights[1] ≈ 1. function invokesum(weights::Vector{Float64}, data::Vector{Float64})::Float64 @@ -1244,8 +1243,7 @@ end weights = [0.2, 0.8] dweights = [0.0, 0.0] - Enzyme.autodiff(Enzyme.Reverse, invokesum, Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) - Enzyme.API.runtimeActivity!(false) + Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), invokesum, Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) @test dweights[1] ≈ 20. @test dweights[2] ≈ 20. end @@ -1388,9 +1386,7 @@ end @testset "AbstractType calling convention" begin # TODO get rid of runtime activity - Enzyme.API.runtimeActivity!(true) - @test 1.0 ≈ Enzyme.autodiff(Reverse, dxdt_pred, Active(1.0))[1][1] - Enzyme.API.runtimeActivity!(false) + @test 1.0 ≈ Enzyme.autodiff(set_runtime_activity(Reverse), dxdt_pred, Active(1.0))[1][1] end function fillsum(x) @@ -1424,11 +1420,9 @@ function rtg_f(V,@nospecialize(cv)) end @testset "RuntimeActivity generic call" begin - Enzyme.API.runtimeActivity!(true) - res = autodiff(Forward, rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) + res = autodiff(set_runtime_activity(Forward), rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) @test 3.14 ≈ res[1] @test 0.0 ≈ res[2] - Enzyme.API.runtimeActivity!(false) end @inline function myquantile(v::AbstractVector, p::Real; alpha) @@ -2523,14 +2517,11 @@ end @testset "Getfield with reference" begin - Enzyme.API.runtimeActivity!(true) - d = GFNamedDist((;a = GFNormal(0.0, 1.0), b = GFProductDist([GFUniform(0.0, 1.0), GFUniform(0.0, 1.0)]))) p = (a = 1.0, b = [0.5, 0.5]) dp = Enzyme.make_zero(p) GFlogpdf(d, p) - autodiff(Reverse, GFlogpdf, Active, Const(d), Duplicated(p, dp)) - Enzyme.API.runtimeActivity!(false) + autodiff(set_runtime_activity(Reverse), GFlogpdf, Active, Const(d), Duplicated(p, dp)) end @testset "BLAS" begin @@ -2630,6 +2621,7 @@ end @testset "Union i8" begin args = ( Val{(false, false, false)}, + 
Val(false), Val(1), Val((true, true, true)), Base.Val(NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}), @@ -2647,6 +2639,7 @@ end args2 = ( Val{(false, false, false)}, + Val(false), Val(1), Val((true, true, true)), Base.Val(NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}), @@ -2664,13 +2657,13 @@ end end @testset "Batched inactive" begin - augres = Enzyme.Compiler.runtime_generic_augfwd(Val{(false, false, false)}, Val(2), Val((true, true, true)), + augres = Enzyme.Compiler.runtime_generic_augfwd(Val{(false, false, false)}, Val(false), Val(2), Val((true, true, true)), Val(Enzyme.Compiler.AnyArray(2+Int(2))), ==, nothing, nothing, :foo, nothing, nothing, :bar, nothing, nothing) - Enzyme.Compiler.runtime_generic_rev(Val{(false, false, false)}, Val(2), Val((true, true, true)), augres[end], + Enzyme.Compiler.runtime_generic_rev(Val{(false, false, false)}, Val(false), Val(2), Val((true, true, true)), augres[end], ==, nothing, nothing, :foo, nothing, nothing, :bar, nothing, nothing) @@ -3566,9 +3559,7 @@ end fn(0.0) end - Enzyme.API.runtimeActivity!(true) - res = autodiff(Forward, Const(f2), Duplicated, Duplicated(0.2, 1.0)) - Enzyme.API.runtimeActivity!(false) + res = autodiff(set_runtime_activity(Forward), Const(f2), Duplicated, Duplicated(0.2, 1.0)) @test res[1] ≈ 0.2 # broken as the return of an apply generic is {primal, primal} # but since the return is abstractfloat doing the diff --git a/test/sc.jl b/test/sc.jl new file mode 100644 index 0000000000..6978908201 --- /dev/null +++ b/test/sc.jl @@ -0,0 +1,64 @@ +module ReverseRules + +using Enzyme +using Enzyme: EnzymeRules +using LinearAlgebra +using Test + +f(x) = x^2 + +function f_ip(x) + x[1] *= x[1] + return nothing +end + +import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig +using .EnzymeRules + +Enzyme.API.printall!(true) + +struct Closure + v::Vector{Float64} +end + +function (cl::Closure)(x) + val = cl.v[1] * x + cl.v[1] = 0.0 + 
return val +end + + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure}, + ::Type{<:Active}, args::Vararg{Active,N}) where {N} + vec = copy(func.val.v) + pval = func.val(args[1].val) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure}, + dret::Active, tape, args::Vararg{Active,N}) where {N} + + @show tape + @show dret + @show args + dargs = ntuple(Val(N)) do i + fval = 7 * args[1].val * dret.val + tape[1] * 1000 + @show fval + fval + end + return dargs +end + +@testset "Closure rule" begin + cl = Closure([3.14]) + res = autodiff(Reverse, cl, Active, Active(2.7))[1][1] + @test res ≈ 7 * 2.7 + 3.14 * 1000 + @test cl[1] ≈ 0.0 +end + +end # ReverseRules From e10ad8ca364026e82f68f07a7d15c178161fbeea Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 11:32:17 -0500 Subject: [PATCH 281/495] Update Project.toml (#1831) * Update Project.toml * fix * fix --- Project.toml | 3 +-- src/compiler/interpreter.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 1ea7b5c05b..e7baec9129 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.13.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" @@ -36,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8" -Enzyme_jll = "0.0.146, 0.0.148" +Enzyme_jll = "0.0.149" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, =9.0" LogExpFunctions = "0.3" diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 
61a433af4c..2ef66a1571 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -41,7 +41,7 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, @assert world <= Base.get_world_counter() parms = @static if VERSION < v"1.12" - InferenceParams(unoptimize_throw_blocks=false), + InferenceParams(unoptimize_throw_blocks=false) else InferenceParams() end From 7faa4108eaebc5c01e99e22bd09ef8f74ad74fb5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 13:43:02 -0500 Subject: [PATCH 282/495] Fix rand set (#1833) --- src/internal_rules.jl | 62 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 96f774f69e..238f7f7b03 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -964,3 +964,65 @@ function EnzymeRules.reverse( ) return () end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + Ty.val(rng.val, dst.val, smpl.val) + if RT <: Duplicated + fill!(dst.dval, 0) + Duplicated(dst.val, dst.dval) + elseif RT <: Const + dst.val + elseif RT <: DuplicatedNoNeed + fill!(dst.dval, 0) + dst.dval + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end + if RT <: BatchDuplicated + BatchDuplicated(dst.val, dst.dval) + else + dst.dval + end + end +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + Ty.val(rng.val, dst.val, 
smpl.val) + if RT <: Duplicated || RT <: DuplicatedNoNeed + fill!(dst.dval, 0) + dst.dval + elseif RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end + end + return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? dst.val : nothing, EnzymeRules.needs_shadow(config) ? dst.dval : nothing, nothing) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + tape, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + return (nothing, nothing, nothing) +end From d8b09f75dbeef3fb931974ac3781a7ecb55548ff Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 19:57:43 -0500 Subject: [PATCH 283/495] Jitrules batched fn (#1835) --- src/rules/jitrules.jl | 59 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f3f05087c0..d5818ecf2f 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -191,6 +191,21 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) else :(return ReturnType((res[1], res[2]...))) end + dup = if Width == 1 + :(Duplicated(f, df)) + else + fargs = [:df] + for i in 2:Width + push!(fargs, Symbol("df_$i")) + end + :(BatchDuplicated(f, ($(fargs...),))) + end + dupty = if Width == 1 + :(Duplicated{FT}) + else + :(BatchDuplicated{FT, $Width}) + end + return quote args = ($(wrapped...),) @@ -218,9 +233,9 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = thunk(opt_mi, (dupClosure ? 
Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward = thunk(opt_mi, dupClosure ? $dupty : Const{FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) + res = forward(dupClosure ? $dup : Const(f), args...) if length(res) == 0 return ReturnType(($(nnothing...),)) @@ -304,6 +319,21 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) else :(return ReturnType((origRet, shadow_return..., tape))) end + + dup = if Width == 1 + :(Duplicated(f, df)) + else + fargs = [:df] + for i in 2:Width + push!(fargs, Symbol("df_$i")) + end + :(BatchDuplicated(f, ($(fargs...),))) + end + dupty = if Width == 1 + :(Duplicated{FT}) + else + :(BatchDuplicated{FT, $Width}) + end return quote $(active_refs...) @@ -331,11 +361,11 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, + forward, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...) + internal_tape, origRet, initShadow = forward(dupClosure0 ? $dup : Const(f), args...) 
annotation = annotationA resT = typeof(origRet) @@ -435,6 +465,21 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act @inbounds Types[i] = Symbol("type_$i") end + dup = if Width == 1 + :(Duplicated(f, df)) + else + fargs = [:df] + for i in 2:Width + push!(fargs, Symbol("df_$i")) + end + :(BatchDuplicated(f, ($(fargs...),))) + end + dupty = if Width == 1 + :(Duplicated{FT}) + else + :(BatchDuplicated{FT, $Width}) + end + quote $(active_refs...) args = ($(wrapped...),) @@ -460,14 +505,14 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act world = codegen_world_age(FT, tt) opt_mi = Val(world) - _, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, + _, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + adjoint(dupClosure0 ? $dup : Const(f), args..., $shadowret, tape.internal_tape)[1] else - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] end $(outs...) From b02cb6dc26bacb1f9f94e0817b3abbe6b5ae1c55 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 22:11:55 -0500 Subject: [PATCH 284/495] Force usage of full typetree on copy/memset (#1838) * Force usage of full typetree on copy/memset * fix * fix * fix * fix * fix * fix * fix * hopefully final fix? 
* Update Project.toml --- src/compiler.jl | 21 ++++++- src/typetree.jl | 163 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 144 insertions(+), 40 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fb7ce8d8bf..df3db3c086 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5,7 +5,7 @@ import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, BatchDuplicatedFunc, Annotation, guess_activity, eltype, - API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, + API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, to_fullmd, TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype using Enzyme @@ -6123,7 +6123,26 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if length(blocks(fn)) != 0 continue end + + intr = LLVM.API.LLVMGetIntrinsicID(fn) + + if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id + legal, jTy = abs_typeof(operands(inst)[1]) + sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id + operands(inst)[3] + else + operands(inst)[3] + end + if legal && Base.isconcretetype(jTy) + if !(jTy isa UnionAll || jTy isa Union || jTy == Union{} || jTy === Tuple || (is_concrete_tuple(jTy) && any(T2 isa Core.TypeofVararg for T2 in jTy.parameters))) + if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz) + metadata(inst)["enzyme_truetype"] = to_fullmd(jTy) + end + end + end + end end + ty = value_type(inst) if ty == LLVM.VoidType() continue diff --git a/src/typetree.jl b/src/typetree.jl index 40b01edcce..89e5a040f3 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -59,6 +59,119 @@ function merge!(dst::TypeTree, src::TypeTree; consume=true) return nothing end +@inline function typetree_primitive(t) + return nothing +end +@inline function typetree_primitive(::Type{T}) where 
{T<:Integer} + return API.DT_Integer +end +@inline function typetree_primitive(::Type{Char}) + return API.DT_Integer +end +@inline function typetree_primitive(::Type{Float16}) + return API.DT_Half +end +@inline function typetree_primitive(::Type{Float32}) + return API.DT_Float +end +@inline function typetree_primitive(::Type{Float64}) + return API.DT_Double +end + + +@static if VERSION >= v"1.11-" +const TypeTreePrimitives = ( + Char, + Float16, + Float32, + Float64, + Core.BFloat16 +) +else +const TypeTreePrimitives = ( + Char, + Float16, + Float32, + Float64 +) +end + +const TypeTreeEmptyPointers = ( + BigFloat, + Any, + Symbol, + Union{}, +) + +function get_offsets(@nospecialize(T::Type)) + for sT in (Integer, TypeTreePrimitives...) + if T <: sT + return ((typetree_primitive(T), 0),) + end + end + for sT in (DataType, AbstractString, TypeTreeEmptyPointers...) + if T <: sT + return ((API.DT_Pointer, 0),) + end + end + +@static if VERSION < v"1.11-" + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array) +else + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array, GenericMemory) +end + for sT in TypeTreeEmptyPointers + if T <: sT + return ((API.DT_Pointer, 0),) + end + end + + @assert !(T <: AbstractFloat) + + if fieldcount(T) == 0 + return () + end + + results = Tuple{API.CConcreteType, Int}[] + for f in 1:fieldcount(T) + offset = fieldoffset(T, f) + subT = fieldtype(T, f) + + if !allocatedinline(subT) || subT isa UnionAll || subT isa Union || subT == Union{} + push!(results, (API.DT_Pointer, offset)) + continue + end + + for (sT, sO) in get_offsets(subT) + push!(results, (sT, sO+offset)) + end + end + return results +end + +function to_fullmd(@nospecialize(T::Type)) + mds = LLVM.Metadata[] + for (sT, sO) in get_offsets(T) + if sT == API.DT_Pointer + push!(mds, LLVM.MDString("Pointer")) + elseif sT == API.DT_Integer + push!(mds, LLVM.MDString("Integer")) + elseif sT == API.DT_Half + push!(mds, LLVM.MDString("Float@half")) + elseif sT == 
API.DT_Float + push!(mds, LLVM.MDString("Float@float")) + elseif sT == API.DT_BFloat16 + push!(mds, LLVM.MDString("Float@bfloat16")) + elseif sT == API.DT_Double + push!(mds, LLVM.MDString("Float@double")) + else + @assert false + end + push!(mds, LLVM.Metadata(LLVM.ConstantInt(sO))) + end + return LLVM.MDNode(mds) +end + function to_md(tt::TypeTree, ctx) return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, @@ -91,48 +204,28 @@ function typetree(@nospecialize(T::Type), ctx, dl, seen=TypeTreeTable()) return tree::TypeTree end -function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:Integer} - return TypeTree(API.DT_Integer, -1, ctx) -end - -function typetree_inner(::Type{Char}, ctx, dl, seen::TypeTreeTable) +function typetree_inner(::Type{<:Integer}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Integer, -1, ctx) end - -function typetree_inner(::Type{Float16}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_Half, -1, ctx) -end - -function typetree_inner(::Type{Float32}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_Float, -1, ctx) -end - -function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_Double, -1, ctx) -end - -@static if VERSION >= v"1.11-" -function typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_BFloat16, -1, ctx) -end -end - -function typetree_inner(::Type{BigFloat}, ctx, dl, seen::TypeTreeTable) - return TypeTree() +for sT in TypeTreePrimitives + @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) + return TypeTree($(typetree_primitive(sT)), -1, ctx) + end end function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable) return TypeTree() end - -function typetree_inner(::Type{Any}, ctx, dl, seen::TypeTreeTable) +function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) return TypeTree() end - -function 
typetree_inner(::Type{Symbol}, ctx, dl, seen::TypeTreeTable) - return TypeTree() +for sT in TypeTreeEmptyPointers + @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) + return TypeTree() + end end + function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() for i in 0:(sizeof(Csize_t) - 1) @@ -141,14 +234,6 @@ function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) return tt end -function typetree_inner(::Type{Union{}}, ctx, dl, seen::TypeTreeTable) - return TypeTree() -end - -function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) - return TypeTree() -end - function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, seen::TypeTreeTable) where {T} tt = copy(typetree(T, ctx, dl, seen)) From 1992f33f940bc9822b32b70164d1c3e9c368ba53 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 00:04:50 -0500 Subject: [PATCH 285/495] Autodiff with do blocks (#1840) --- src/Enzyme.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index fcc12d57a8..4ad8a4b061 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -639,6 +639,28 @@ result, ∂v, ∂A Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end +""" + autodiff(::Function, ::Mode, args...) + +Specialization of [`autodiff`](@ref) to handle do argument closures. + +```jldoctest + +autodiff(Reverse, Active(3.1)) do x + return x*x +end + +# output +((6.2,),) +``` +""" +@inline function autodiff(f::Function, m::MMode, ::Type{A}, args::Vararg{Annotation, Nargs}) where {A<:Annotation, Nargs, MMode<:Mode} + autodiff(m, f, A, args...) +end +@inline function autodiff(f::Function, m::MMode, args::Vararg{Annotation, Nargs}) where {Nargs, MMode<:Mode} + autodiff(m, f, args...) 
+end + """ autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) From 8cd61a529192373bcc1445cbd1bcd6e40ab2ca51 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Sep 2024 01:32:17 -0400 Subject: [PATCH 286/495] Fixup docs --- examples/custom_rule.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 86ffcf234a..2b3f226fb0 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -168,7 +168,7 @@ g(y, x) = f(y, x)^2 # function to differentiate # Let's look at how to write a simple reverse-mode rule! # First, we write a method for [`EnzymeRules.augmented_primal`](@ref): -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, y::Duplicated, x::Duplicated) println("In custom augmented primal rule.") ## Compute primal @@ -203,7 +203,7 @@ end # Now, we write a method for [`EnzymeRules.reverse`](@ref): -function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape, +function reverse(config::RevConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape, y::Duplicated, x::Duplicated) println("In custom reverse rule.") ## retrieve x value, either from original x or from tape if x may have been overwritten. 
From b94e3f490ff5d1ae830417d72b09d8ee91808e98 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 12:17:41 -0500 Subject: [PATCH 287/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e7baec9129..19315d01dc 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8" -Enzyme_jll = "0.0.149" +Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, =9.0" LogExpFunctions = "0.3" From 6dc7c8f2a46c975bd76610a6b79c969d7161746e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 14:33:31 -0500 Subject: [PATCH 288/495] Move return primal into forward mode (#1832) * Move return primal into forward mode * fix * fix * more fixups * fix * fix * fix etu * fix * fix * fix * fixup * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update Project.toml * fix * Update internal_rules.jl * fix * fix * fix * fix --- lib/EnzymeCore/src/EnzymeCore.jl | 23 +-- lib/EnzymeCore/src/rules.jl | 21 +- lib/EnzymeTestUtils/src/test_forward.jl | 18 +- src/Enzyme.jl | 244 +++++++++++++----------- src/compiler.jl | 16 +- src/internal_rules.jl | 182 +++++++++++------- src/rules/customrules.jl | 6 +- src/rules/jitrules.jl | 8 +- test/abi.jl | 56 +++--- test/applyiter.jl | 14 +- test/ext/chainrulescore.jl | 13 +- test/rules.jl | 14 +- test/runtests.jl | 40 ++-- 13 files changed, 373 insertions(+), 282 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 0175cb4caf..cc71f0f9c6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -1,8 +1,8 @@ module EnzymeCore -export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal +export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, 
ReverseSplitWithPrimal export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed +export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc @@ -267,22 +267,23 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau @inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() """ - struct Forward{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} + struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} Forward mode differentiation """ -struct ForwardMode{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} +struct ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end -const Forward = ForwardMode{DefaultABI, false, false}() +const Forward = ForwardMode{false, DefaultABI, false, false}() +const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() -@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,true,RuntimeActivity}() -@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,false,RuntimeActivity}() +@inline 
set_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,true,RuntimeActivity}() +@inline clear_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,false,RuntimeActivity}() -@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten,RuntimeActivity}() +@inline set_abi(::ForwardMode{ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{NewABI}) where {ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{ReturnPrimal,NewABI,ErrIfFuncWritten,RuntimeActivity}() -@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,true}() -@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,rt}() -@inline clear_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,false}() +@inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,true}() +@inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,rt}() +@inline clear_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = 
ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,false}() function autodiff end function autodiff_deferred end diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 27b14619e3..8d01d321da 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -22,24 +22,29 @@ The third argument is the return type annotation, and all other arguments are th function forward end """ - FwdConfig{Width, RuntimeActivity} - FwdConfigWidth{Width} = FwdConfig{Width} + FwdConfig{NeedsPrimal, NeedsShadow, Width, RuntimeActivity} + FwdConfigWidth{Width} = FwdConfig{<:Any, <:Any, Width} Configuration type to dispatch on in custom forward rules (see [`forward`](@ref)). +* `NeedsPrimal` and `NeedsShadow`: boolean values specifying whether the primal and shadow (resp.) should be returned. * `Width`: an integer that specifies the number of adjoints/shadows simultaneously being propagated. * `RuntimeActivity`: whether runtime activity is enabled. -Getters for the type parameters are provided by `width` and `runtime_activity`. +Getters for the type parameters are provided by `needs_primal`, `needs_shadow`, `width` and `runtime_activity`.
""" -struct FwdConfig{Width, RuntimeActivity} end -const FwdConfigWidth{Width} = FwdConfig{Width} -@inline width(::FwdConfig{Width}) where Width = Width -@inline runtime_activity(::FwdConfig{<:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity +struct FwdConfig{NeedsPrimal, NeedsShadow, Width, RuntimeActivity} end +const FwdConfigWidth{Width} = FwdConfig{<:Any,<:Any,Width} + +@inline needs_primal(::FwdConfig{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +@inline needs_shadow(::FwdConfig{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow + +@inline width(::FwdConfig{<:Any, <:Any, Width}) where Width = Width +@inline runtime_activity(::FwdConfig{<:Any, <:Any, <:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity """ RevConfig{NeedsPrimal, NeedsShadow, Width, Overwritten, RuntimeActivity} - RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} + RevConfigWidth{Width} = RevConfig{<:Any, <:Any, Width} Configuration type to dispatch on in custom reverse rules (see [`augmented_primal`](@ref) and [`reverse`](@ref)). * `NeedsPrimal` and `NeedsShadow`: boolean values specifying whether the primal and shadow (resp.) should be returned. diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index fcfc987cb9..8830ce6784 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -79,13 +79,27 @@ function test_forward( # call finitedifferences, avoid mutating original arguments dy_fdm = _fd_forward(fdm, call_with_copy, ret_activity, y, activities) # call autodiff, allow mutating original arguments - y_and_dy_ad = autodiff(set_runtime_activity(Forward, runtime_activity), call_with_kwargs, ret_activity, activities...) 
+ mode = if ret_activity <: Union{DuplicatedNoNeed,BatchDuplicatedNoNeed, Const} + Forward + else + ForwardWithPrimal + end + mode = set_runtime_activity(mode, runtime_activity) + + ret_activity2 = if ret_activity <: DuplicatedNoNeed + Duplicated + elseif ret_activity <: BatchDuplicatedNoNeed + BatchDuplicated + else + ret_activity + end + y_and_dy_ad = autodiff(mode, call_with_kwargs, ret_activity2, activities...) if ret_activity <: Union{Duplicated,BatchDuplicated} @test_msg( "For return type $ret_activity the return value and derivative must be returned", length(y_and_dy_ad) == 2, ) - y_ad, dy_ad = y_and_dy_ad + dy_ad, y_ad = y_and_dy_ad test_approx( y_ad, y, "The return value of the rule and function must agree"; atol, rtol ) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 4ad8a4b061..bb86a33fc7 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -2,8 +2,8 @@ module Enzyme import EnzymeCore -import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal +import EnzymeCore: Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal +export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, 
clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity @@ -358,38 +358,33 @@ instead use [`Duplicated`](@ref) or variants like [`DuplicatedNoNeed`](@ref). `Activity` is the Activity of the return value, it may be: * `Const` if the return is not to be differentiated with respect to -* `Duplicated`, if the return is being differentiated with respect to and - both the original value and the derivative return are desired -* `DuplicatedNoNeed`, if the return is being differentiated with respect to - and only the derivative return is desired. +* `Duplicated`, if the return is being differentiated with respect to * `BatchDuplicated`, like `Duplicated`, but computing multiple derivatives at once. All batch sizes must be the same for all arguments. -* `BatchDuplicatedNoNeed`, like `DuplicatedNoNeed`, but computing multiple - derivatives at one. All batch sizes must be the same for all arguments. 
Example returning both original return and derivative: ```jldoctest f(x) = x*x -res, ∂f_∂x = autodiff(Forward, f, Duplicated, Duplicated(3.14, 1.0)) +∂f_∂x, res = autodiff(ForwardWithPrimal, f, Duplicated, Duplicated(3.14, 1.0)) # output -(9.8596, 6.28) +(6.28, 9.8596) ``` Example returning just the derivative: ```jldoctest f(x) = x*x -∂f_∂x = autodiff(Forward, f, DuplicatedNoNeed, Duplicated(3.14, 1.0)) +∂f_∂x = autodiff(Forward, f, Duplicated, Duplicated(3.14, 1.0)) # output (6.28,) ``` """ -@inline function autodiff(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {ReturnPrimal, RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -401,7 +396,9 @@ f(x) = x*x if A <: Active throw(ErrorException("Active Returns not allowed in forward mode")) end - ReturnPrimal = Val(A <: Duplicated || A <: BatchDuplicated) + if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed + throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)")) + end RT = if A <: Duplicated && width != 1 if A isa UnionAll BatchDuplicated{T, width} where T @@ -429,7 +426,7 @@ f(x) = x*x end thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk(f, args...) end @@ -480,7 +477,7 @@ end Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode{ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_deferred(::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -489,6 +486,9 @@ code, as well as high-order differentiation. if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end + if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed + throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + end RT = if A <: Duplicated && width != 1 if A isa UnionAll BatchDuplicated{T, width} where T @@ -524,7 +524,6 @@ code, as well as high-order differentiation. 
throw(ErrorException("Active Returns not allowed in forward mode")) end - ReturnPrimal = RT <: Duplicated || RT <: BatchDuplicated ModifiedBetween = Val(falses_from_args(Nargs+1)) adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) @@ -673,7 +672,7 @@ ftype when called with args of type `argtypes`. The forward function will return the primal (if requested) and the shadow (or nothing if not a `Duplicated` variant). -Example returning both original return and derivative: +Example returning both the return derivative and original return: ```jldoctest a = 4.2 @@ -681,12 +680,12 @@ b = [2.2, 3.3]; ∂f_∂b = zero(b) c = 55; d = 9 f(x) = x*x -forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float64}) +forward = autodiff_thunk(ForwardWithPrimal, Const{typeof(f)}, Duplicated, Duplicated{Float64}) res, ∂f_∂x = forward(Const(f), Duplicated(3.14, 1.0)) # output -(9.8596, 6.28) +(6.28, 9.8596) ``` Example returning just the derivative: @@ -697,7 +696,7 @@ b = [2.2, 3.3]; ∂f_∂b = zero(b) c = 55; d = 9 f(x) = x*x -forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated{Float64}) +forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float64}) ∂f_∂x = forward(Const(f), Duplicated(3.14, 1.0)) # output @@ -705,7 +704,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_thunk(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where 
{ReturnPrimal, FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = same_or_one(1, A, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -713,7 +712,10 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated if A <: Active throw(ErrorException("Active Returns not allowed in forward mode")) end - ReturnPrimal = Val(A <: Duplicated || A <: BatchDuplicated) + if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed + throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + end + ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(eltype, args)...} @@ -724,7 +726,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + results = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end @inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} @@ -1046,6 +1048,14 @@ grad = gradient(Reverse, f, [2.0, 3.0]) 2.0 ``` +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) + +# output +([3.0, 2.0], 6.0) +``` + ```jldoctest gradient grad = gradient(Reverse, only ∘ f, (a = 2.0, b = 
[3.0], c = "str")) @@ -1059,7 +1069,7 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) dx = Ref(make_zero(x)) res = autodiff(rm, f, Active, MixedDuplicated(x, dx)) if ReturnPrimal - (res[2], only(dx)) + (only(dx), res[2]) else only(dx) end @@ -1067,7 +1077,7 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) dx = make_zero(x) res = autodiff(rm, f, Active, Duplicated(x, dx)) if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1084,7 +1094,7 @@ Like [`gradient`](@ref), except it uses deferred mode. dx = Ref(make_zero(x)) autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) if ReturnPrimal - return (res[2], only(dx)) + return (only(dx), res[2]) else return only(dx) end @@ -1092,7 +1102,7 @@ Like [`gradient`](@ref), except it uses deferred mode. dx = make_zero(x) autodiff_deferred(rm, f, Active, Duplicated(x, dx)) if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1108,7 +1118,7 @@ Both `x` and `dx` must be `Array`s of the same type. Example: -```jldoctest +```jldoctest gradip f(x) = x[1]*x[2] dx = [0.0, 0.0] @@ -1120,12 +1130,20 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) 3.0 2.0 ``` + +```jldoctest gradip +dx = [0.0, 0.0] +gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) + +# output +([3.0, 2.0], 6.0) +``` """ @inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1141,7 +1159,7 @@ Like [`gradient!`](@ref), except it uses deferred mode. make_zero!(dx) autodiff_deferred(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1158,7 +1176,7 @@ within this call. 
Example: -```jldoctest +```jldoctest gradfwd f(x) = x[1]*x[2] grad = gradient(Forward, f, [2.0, 3.0]) @@ -1167,17 +1185,35 @@ grad = gradient(Forward, f, [2.0, 3.0]) (3.0, 2.0) ``` + +```jldoctest gradfwd +gradient(ForwardWithPrimal, f, [2.0, 3.0]) + +# output +((3.0, 2.0), 6.0) +``` """ -@inline function gradient(fm::ForwardMode, f, x; shadow=onehot(x)) +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; shadow=onehot(x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} if length(shadow) == 0 - return () + if ReturnPrimal + ((), f(x.val)) + else + return () + end end - res = values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) - if x isa AbstractFloat + resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow)) + + res = values(resp[1]) + dres = if x isa AbstractFloat res[1] else res end + if ReturnPrimal + (dres, resp[2]) + else + dres + end end @inline function chunkedonehot(x, ::Val{chunk}) where chunk @@ -1216,29 +1252,64 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) (3.0, 2.0) ``` """ -@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tmp = ntuple(length(shadow)) do i - values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) - end - res = tupleconcat(tmp...) 
- if x isa AbstractFloat - res[1] + if ReturnPrimal + rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[1]))[1] + dres1 = if chunk == 1 + (rp[1],) + else + values(rp[1]) + end + gres = if x isa AbstractFloat + dres1 + else + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + tmp = ntuple(length(shadow)-1) do i + values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadow[i+1]))[1]) + end + tupleconcat(dres1, tmp...) + end + (gres, rp[2]) else - res + tmp = ntuple(length(shadow)) do i + values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[i]))[1]) + end + res = tupleconcat(tmp...) + if x isa AbstractFloat + res[1] + else + res + end end end -@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} - res = ntuple(length(shadow)) do i - autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] - end - if x isa AbstractFloat - res[1] +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} + if ReturnPrimal + rp = autodiff(fm, f, Duplicated, Duplicated(x, shadow[1])) + dres1 = rp[1] + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + + res = ntuple(length(shadow)-1) do i + autodiff(fm2, f, Duplicated, Duplicated(x, shadow[i+1]))[1] + end + gres = if x isa AbstractFloat + dres1 + else + (dres1, res...) + end + (gres, rp[2]) else - res + res = ntuple(length(shadow)) do i + autodiff(fm, f, Duplicated, Duplicated(x, shadow[i]))[1] + end + if x isa AbstractFloat + res[1] + else + res + end end end @@ -1270,46 +1341,16 @@ whose shape is `(size(output)..., size(input)...)` For functions who return other types, this function will retun an array or tuple of shape `size(input)` of values of the output type. 
""" -@inline function jacobian(fm::ForwardMode, f, x; shadow=onehot(x)) - cols = if length(shadow) == 0 - () +@inline function jacobian(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, args...; kwargs...) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} + gradtup = gradient(fm, args...; kwargs...) + cols = if ReturnPrimal + gradtup[1] else - values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + gradtup end - if x isa AbstractFloat - cols[1] - elseif length(cols) > 0 && cols[1] isa AbstractArray - inshape = size(x) - outshape = size(cols[1]) - # st : outshape x total inputs - st = Base.stack(cols) - - st3 = if length(inshape) <= 1 - st - else - reshape(st, (outshape..., inshape...)) - end - - st3 - elseif x isa AbstractArray - inshape = size(x) - reshape(collect(cols), inshape) - else + x = args[2] + res = if x isa AbstractFloat cols - end -end - -@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} - if chunk == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - tmp = ntuple(length(shadow)) do i - Base.@_inline_meta - values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) - end - cols = tupleconcat(tmp...) 
- if x isa AbstractFloat - cols[1] elseif length(cols) > 0 && cols[1] isa AbstractArray inshape = size(x) outshape = size(cols[1]) @@ -1329,33 +1370,10 @@ end else cols end -end - -@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} - cols = ntuple(length(shadow)) do i - Base.@_inline_meta - autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] - end - if x isa AbstractFloat - cols[1] - elseif length(cols) > 0 && cols[1] isa AbstractArray - inshape = size(x) - outshape = size(cols[1]) - # st : outshape x total inputs - st = Base.stack(cols) - - st3 = if length(inshape) <= 1 - st - else - reshape(st, (outshape..., inshape...)) - end - - st3 - elseif x isa AbstractArray - inshape = size(x) - reshape(collect(cols), inshape) + if ReturnPrimal + (res, gradtup[2]) else - cols + res end end diff --git a/src/compiler.jl b/src/compiler.jl index df3db3c086..1d21fb99a1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -637,7 +637,7 @@ end return Const{T} end if Mode == API.DEM_ForwardMode - return DuplicatedNoNeed{T} + return Duplicated{T} else if ActReg == ActiveState return Active{T} @@ -4216,9 +4216,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end if Mode == API.DEM_ForwardMode - if returnPrimal - push!(sret_types, literal_rt) - end if !(rettype <: Const) if width == 1 push!(sret_types, literal_rt) @@ -4226,6 +4223,9 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) end end + if returnPrimal + push!(sret_types, literal_rt) + end end combinedReturn = if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) @@ -4562,7 +4562,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, val else @assert count_llvm_Sret > 1 - extract_value!(builder, val, returnNum) + extract_value!(builder, val, 1-returnNum) end) ptr = 
inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) @@ -6906,7 +6906,7 @@ end push!(sret_types, TapeType) end - if returnPrimal + if returnPrimal && !(CC <: ForwardModeThunk) push!(sret_types, jlRT) end if is_forward @@ -6930,6 +6930,10 @@ end end end + if returnPrimal && (CC <: ForwardModeThunk) + push!(sret_types, jlRT) + end + # calls fptr llvmtys = LLVMType[convert(LLVMType, x; allow_boxed=true) for x in types] diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 238f7f7b03..f29ed0d977 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -573,12 +573,14 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] - if RT <: Const - return xs.val - elseif RT <: DuplicatedNoNeed + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return xs + elseif EnzymeRules.needs_shadow(config) return xs.dval + elseif EnzymeRules.needs_primal(config) + return xs.val else - return xs + return nothing end end @@ -593,12 +595,14 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, for i in 1:N xs.dval[i] .= xs.dval[i][inds] end - if RT <: Const - return xs.val - elseif RT <: BatchDuplicatedNoNeed + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return xs + elseif EnzymeRules.needs_shadow(config) return xs.dval + elseif EnzymeRules.needs_primal(config) + return xs.val else - return xs + return nothing end end @@ -652,16 +656,19 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, partialsortperm!(inds, xs.val, kv; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] - if RT <: Const - return kv isa Integer ? xs.val[kv] : view(xs.val, kv) - elseif RT <: DuplicatedNoNeed - return kv isa Integer ? 
xs.dval[kv] : view(xs.dval, kv) - else + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if kv isa Integer return Duplicated(xs.val[kv], xs.dval[kv]) else return Duplicated(view(xs.val, kv), view(xs.dval, kv)) end + elseif EnzymeRules.needs_shadow(config) + return kv isa Integer ? xs.dval[kv] : view(xs.dval, kv) + elseif EnzymeRules.needs_primal(config) + return kv isa Integer ? xs.val[kv] : view(xs.val, kv) + else + return nothing end end @@ -679,20 +686,23 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, for i in 1:N xs.dval[i] .= xs.dval[i][inds] end - if RT <: Const - return kv isa Integer ? xs.val[kv] : view(xs.val, kv) - elseif RT <: BatchDuplicatedNoNeed + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if kv isa Integer - return ntuple(i -> xs.dval[i][kv], N) + return BatchDuplicated(xs.val[kv], ntuple(i -> xs.dval[i][kv], N)) else - return ntuple(i -> view(xs.dval[i], kv), N) + return BatchDuplicated(view(xs.val, kv), ntuple(i -> view(xs.dval[i], kv), N)) end - else + elseif EnzymeRules.needs_shadow(config) if kv isa Integer - return BatchDuplicated(xs.val[kv], ntuple(i -> xs.dval[i][kv], N)) + return ntuple(i -> xs.dval[i][kv], N) else - return BatchDuplicated(view(xs.val, kv), ntuple(i -> view(xs.dval[i], kv), N)) + return ntuple(i -> view(xs.dval[i], kv), N) end + elseif EnzymeRules.needs_primal(config) + return kv isa Integer ? xs.val[kv] : view(xs.val, kv) + else + return nothing end end @@ -756,7 +766,12 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(l B::Annotation{<:AbstractVecOrMat}; kwargs...) if B isa Const - return func.val(fact.val, B.val; kwargs...) + retval = func.val(fact.val, B.val; kwargs...) 
+ if EnzymeRules.needs_primal(config) + retval + else + return nothing + end else N = EnzymeRules.width(config) retval = B.val @@ -787,16 +802,23 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(l return dB end - if RT <: Const + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Duplicated(retval, dretvals[1]) + else + return BatchDuplicated(retval, dretvals) + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return dretvals[1] + else + return dretvals + end + elseif EnzymeRules.needs_primal(config) return retval - elseif RT <: DuplicatedNoNeed - return dretvals[1] - elseif RT <: Duplicated - return Duplicated(retval, dretvals[1]) - elseif RT <: BatchDuplicatedNoNeed - return dretvals else - return BatchDuplicated(retval, dretvals) + return nothing end end end @@ -830,23 +852,27 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, error("Annotation type $(typeof(start)) not supported for range step. Please open an issue") end - if RT <: Duplicated - Duplicated(ret, range(dstart; step=dstep, length=length(ret))) - elseif RT <: Const - ret - elseif RT <: DuplicatedNoNeed - range(dstart; step=dstep, length=length(ret)) - elseif RT <: BatchDuplicated - BatchDuplicated(ret, + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Duplicated(ret, range(dstart; step=dstep, length=length(ret))) + else + return BatchDuplicated(ret, ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], length=length(ret)), Val(EnzymeRules.width(config)))) - elseif RT <: BatchDuplicatedNoNeed - ntuple(i -> range(dstart isa Number ? 
dstart : dstart[i]; + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return range(dstart; step=dstep, length=length(ret)) + else + return ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], length=length(ret)), Val(EnzymeRules.width(config))) + end + elseif EnzymeRules.needs_primal(config) + return ret else - error("This should not be possible. Please report.") + return nothing end end @@ -908,24 +934,30 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; kwargs... ) - if RT <: Const - return Ty.val(; kwargs...) - elseif RT <: DuplicatedNoNeed - return Ty.val(; kwargs...) - elseif RT <: Duplicated - return RT(Ty.val(; kwargs...), Ty.val(; kwargs...)) - elseif RT <: BatchDuplicatedNoNeed - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - Ty.val(; kwargs...) + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return RT(Ty.val(; kwargs...), Ty.val(; kwargs...)) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + Ty.val(; kwargs...) + end + return RT(Ty.val(; kwargs...), tup) end - else - @assert RT <: BatchDuplicated - tup = ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - Ty.val(; kwargs...) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Ty.val(; kwargs...) + else + return ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + Ty.val(; kwargs...) + end end - RT(Ty.val(; kwargs...), tup) + elseif EnzymeRules.needs_primal(config) + return Ty.val(; kwargs...) 
+ else + return nothing end end @@ -973,26 +1005,28 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} Ty.val(rng.val, dst.val, smpl.val) - if RT <: Duplicated - fill!(dst.dval, 0) - Duplicated(dst.val, dst.dval) - elseif RT <: Const - dst.val - elseif RT <: DuplicatedNoNeed - fill!(dst.dval, 0) - dst.dval - else - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - fill!(dst.dval[i], 0) - nothing - end - if RT <: BatchDuplicated - BatchDuplicated(dst.val, dst.dval) + + if !(dst isa Const) + if EnzymeRules.width(config) == 1 + fill!(dst.dval, 0) else - dst.dval + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end end end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + dst + elseif EnzymeRules.needs_shadow(config) + dst.dval + elseif EnzymeRules.needs_primal(config) + dst.val + else + nothing + end end function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 75e36370d8..e0eae36e4d 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -352,7 +352,6 @@ end end width = get_width(gutils) - C = EnzymeRules.FwdConfig{Int(width), get_runtime_activity(gutils)} if shadowR != C_NULL unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) @@ -374,6 +373,8 @@ end args, activity, overwritten, actives, kwtup, _ = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) + C = EnzymeRules.FwdConfig{Bool(needsPrimal), Bool(needsShadow), Int(width), get_runtime_activity(gutils)} + alloctx = LLVM.IRBuilder() position!(alloctx, 
LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) mode = get_mode(gutils) @@ -494,8 +495,7 @@ end normalV = C_NULL if RT <: Const - # TODO introduce const-no-need - if needsPrimal || true + if needsPrimal if RealRt != fwd_RT emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT)) return false diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index d5818ecf2f..01edec7118 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -187,9 +187,9 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end retres = if Width == 1 - :(return ReturnType((res[1], res[2]))) + :(return ReturnType((res[2], res[1]))) else - :(return ReturnType((res[1], res[2]...))) + :(return ReturnType((res[2], res[1]...))) end dup = if Width == 1 :(Duplicated(f, df)) @@ -764,9 +764,9 @@ function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width} ReturnType(allFirst(Val(width+1), res)) else if width == 1 - ReturnType((res[1], res[2])) + ReturnType((res[2], res[1])) else - ReturnType((res[1], res[2]...)) + ReturnType((res[2], res[1]...)) end end end diff --git a/test/abi.jl b/test/abi.jl index e07b7403ce..63fe48dc61 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -45,7 +45,7 @@ using Test @test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im)) cres, = autodiff(ReverseHolomorphic, f, Active, Active(1.5 + 0.7im))[1] @test cres ≈ 1.0 + 0.0im - cres, = autodiff(Forward, f, DuplicatedNoNeed, Duplicated(1.5 + 0.7im, 1.0 + 0im)) + cres, = autodiff(Forward, f, Duplicated, Duplicated(1.5 + 0.7im, 1.0 + 0im)) @test cres ≈ 1.0 + 0.0im @test_throws ErrorException autodiff(Reverse, f, Active(1.5 + 0.7im)) @@ -68,12 +68,12 @@ using Test _, res0 = autodiff(Enzyme.set_abi(Reverse, NonGenABI), unused, Active, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 - res0, = autodiff(Forward, unused, 
DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) + res0, = autodiff(Forward, unused, Duplicated, Const(nothing), Duplicated(2.0, 1.0)) @test res0 ≈ 1.0 - res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), DuplicatedNoNeed(2.0, 1.0)) + res0, = autodiff(Forward, unused, Duplicated, Const(nothing), DuplicatedNoNeed(2.0, 1.0)) @test res0 ≈ 1.0 - res0, = autodiff(Enzyme.set_abi(Forward, NonGenABI), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) + res0, = autodiff(Enzyme.set_abi(Forward, NonGenABI), unused, Duplicated, Const(nothing), Duplicated(2.0, 1.0)) @test res0 ≈ 1.0 _, res0 = autodiff(Reverse, unused, Const(nothing), Active(2.0))[1] @@ -193,7 +193,7 @@ using Test res2, = autodiff(Reverse, g, Active, Active(Foo(3, 1.2)))[1] @test res2.qux ≈ 1.0 - @test 1.0≈ first(autodiff(Forward, g, DuplicatedNoNeed, Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) + @test 1.0≈ first(autodiff(Forward, g, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) res2, = autodiff(Reverse, g, Active(Foo(3, 1.2)))[1] @test res2.qux ≈ 1.0 @@ -204,7 +204,7 @@ using Test _, resF = autodiff(Reverse, unused2, Active, Const(nothing), Active(Foo(3, 2.0)))[1] @test resF.qux ≈ 1.0 - @test 1.0≈ first(autodiff(Forward, unused2, DuplicatedNoNeed, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) + @test 1.0≈ first(autodiff(Forward, unused2, Duplicated, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) _, resF = autodiff(Reverse, unused2, Const(nothing), Active(Foo(3, 2.0)))[1] @test resF.qux ≈ 1.0 @@ -216,7 +216,7 @@ using Test @test res3[1].qux ≈ 3.4 @test res3[2].qux ≈ 1.2 - @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, DuplicatedNoNeed, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0)))) + @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0)))) res3 = autodiff(Reverse, h, Active(Foo(3, 1.2)), Active(Foo(5, 3.4)))[1] @test res3[1].qux ≈ 3.4 @@ 
-228,7 +228,7 @@ using Test _, res4 = autodiff(Reverse, caller, Active, Const((x)->x), Active(3.0))[1] @test res4 ≈ 1.0 - res4, = autodiff(Forward, caller, DuplicatedNoNeed, Const((x)->x), Duplicated(3.0, 1.0)) + res4, = autodiff(Forward, caller, Duplicated, Const((x)->x), Duplicated(3.0, 1.0)) @test res4 ≈ 1.0 _, res4 = autodiff(Reverse, caller, Const((x)->x), Active(3.0))[1] @@ -257,7 +257,7 @@ using Test @test ad === ((nothing,),) @test shadow.val ≈ 1.0 && shadow.next.val ≈ 1.0 - @test 2.0 ≈ first(autodiff(Forward, sumlist, DuplicatedNoNeed, Duplicated(regular, shadow))) + @test 2.0 ≈ first(autodiff(Forward, sumlist, Duplicated, Duplicated(regular, shadow))) mulr(x, y) = x[] * y[] x = Ref(2.0) @@ -273,7 +273,7 @@ using Test y = Ref(3.0) dx = Ref(5.0) dy = Ref(7.0) - @test 5.0*3.0 + 2.0*7.0≈ first(autodiff(Forward, mulr, DuplicatedNoNeed, Duplicated(x, dx), Duplicated(y, dy))) + @test 5.0*3.0 + 2.0*7.0≈ first(autodiff(Forward, mulr, Duplicated, Duplicated(x, dx), Duplicated(y, dy))) _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const((x->x*x,)), Active(2.0))[1] @test mid ≈ 4.0 @@ -281,10 +281,10 @@ using Test _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const([x->x*x]), Active(2.0))[1] @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const((x->x*x,)), Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const((x->x*x,)), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const([x->x*x]), Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const([x->x*x]), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 @@ -394,8 +394,8 @@ end @test Enzyme.autodiff(Reverse, method, Active, Const(AFoo(2.0)), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, AFoo(2.0), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(AFoo(2.0)), 
Duplicated(3.0, 1.0))[1] ≈ 2.0 - @test Enzyme.autodiff(Forward, AFoo(2.0), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, Duplicated, Const(AFoo(2.0)), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, AFoo(2.0), Duplicated, Duplicated(3.0, 1.0))[1] ≈ 2.0 struct ABar end @@ -407,8 +407,8 @@ end @test Enzyme.autodiff(Reverse, method, Active, Const(ABar()), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, ABar(), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(ABar()), Duplicated(3.0, 1.0))[1] ≈ 2.0 - @test Enzyme.autodiff(Forward, ABar(), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, Duplicated, Const(ABar()), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, ABar(), Duplicated, Duplicated(3.0, 1.0))[1] ≈ 2.0 struct RWClos x::Vector{Float64} @@ -446,14 +446,14 @@ end @testset "Promotion" begin x = [1.0, 2.0]; dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]; rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2 - r = autodiff(Forward, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) - @test r[1] ≈ 100.0 - @test r[2][1] ≈ -400.0 - @test r[2][2] ≈ 200.0 - r = autodiff_deferred(Forward, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) - @test r[1] ≈ 100.0 - @test r[2][1] ≈ -400.0 - @test r[2][2] ≈ 200.0 + r = autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) + @test r[2] ≈ 100.0 + @test r[1][1] ≈ -400.0 + @test r[1][2] ≈ 200.0 + r = autodiff_deferred(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) + @test r[2] ≈ 100.0 + @test r[1][1] ≈ -400.0 + @test r[1][2] ≈ 200.0 end abssum(x) = sum(abs2, x); @@ -467,11 +467,14 @@ mulsin(x) = sin(x[1] * x[2]) @inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x,x)) @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x,x)) @inferred 
autodiff(Enzyme.Forward, abssum, Duplicated(x,x)) + @inferred autodiff(Enzyme.ForwardWithPrimal, abssum, Duplicated, Duplicated(x,x)) @inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x,x)) - @inferred autodiff(Enzyme.Forward, abssum, DuplicatedNoNeed, Duplicated(x,x)) @inferred gradient(Reverse, abssum, x) @inferred gradient!(Reverse, x, abssum, x) + + @inferred gradient(ReverseWithPrimal, abssum, x) + @inferred gradient!(ReverseWithPrimal, x, abssum, x) cx = ones(10) @inferred autodiff(Enzyme.ReverseHolomorphic, sum, Duplicated(cx,cx)) @@ -489,6 +492,9 @@ mulsin(x) = sin(x[1] * x[2]) @inferred gradient(Reverse, abssum, tx) @inferred gradient(Forward, abssum, tx) + @inferred gradient(ReverseWithPrimal, abssum, tx) + @inferred gradient(ForwardWithPrimal, abssum, tx) + @inferred hvp(mulsin, [2.0, 3.0], [5.0, 2.7]) @inferred hvp!(zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7]) diff --git a/test/applyiter.jl b/test/applyiter.jl index 5b55617e55..642ad62035 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -267,7 +267,7 @@ end @test dres[3] ≈ 100.02 @test dres[4] ≈ 304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat, Duplicated, Duplicated(x, dx)) @test length(res) == 4 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 @@ -290,7 +290,7 @@ end @test dres[3] == "c" @test dres[4] == "d" - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat, Duplicated, Duplicated(a, da)) @test length(res) == 4 @test res[1] == "a" @test res[2] == "b" @@ -313,7 +313,7 @@ end @test dres[4] == "c" @test dres[5] == "d" - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) @test length(res) == 5 @test res[1] ≈ 1.0 @test res[2] == "a" @@ -337,7 
+337,7 @@ end @test dres[4] == "c" @test dres[5] == "d" - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) @test length(res) == 5 @test res[1] ≈ 1.0 @test res[2] == "a" @@ -365,7 +365,7 @@ end @test dres[7] ≈ -9100.02 @test dres[8] ≈ -9304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) @test length(res) == 8 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 @@ -403,7 +403,7 @@ end @test dres[11] ≈ -9100.02 @test dres[12] ≈ -9304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) @test length(res) == 12 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 @@ -449,7 +449,7 @@ end @test dres[2][3] ≈ -9100.02 @test dres[2][4] ≈ -9304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) @test length(res) == 4 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 diff --git a/test/ext/chainrulescore.jl b/test/ext/chainrulescore.jl index b73117faf2..65984ef26f 100644 --- a/test/ext/chainrulescore.jl +++ b/test/ext/chainrulescore.jl @@ -24,8 +24,11 @@ function ChainRulesCore.rrule(::typeof(MockModule.mock_function), x) return y, ȳ -> 2 * ȳ end -fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] -fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x, MockModule.MockType(one(x.x))))[2] +fdiff(f, x::Number) = autodiff(ForwardWithPrimal, f, Duplicated, Duplicated(x, one(x)))[1] +fdiff(f, x::MockModule.MockType) = 
autodiff(ForwardWithPrimal, f, Duplicated, Duplicated(x, MockModule.MockType(one(x.x))))[1] + +fdiff2(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[1] +fdiff2(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x, MockModule.MockType(one(x.x))))[1] @testset "import_frule" begin f1(x) = 2*x @@ -33,6 +36,8 @@ fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x Enzyme.@import_frule typeof(f1) Any @test fdiff(f1, 1f0) === 5f0 @test fdiff(f1, 1.0) === 5.0 + @test fdiff2(f1, 1f0) === 5f0 + @test fdiff2(f1, 1.0) === 5.0 # specific signature f2(x) = 2*x @@ -40,6 +45,8 @@ fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x Enzyme.@import_frule typeof(f2) Float32 @test fdiff(f2, 1f0) === 5f0 @test fdiff(f2, 1.0) === 2.0 + @test fdiff2(f2, 1f0) === 5f0 + @test fdiff2(f2, 1.0) === 2.0 # two arguments f3(x, y) = 2*x + y @@ -47,6 +54,8 @@ fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x Enzyme.@import_frule typeof(f3) Any Any @test fdiff(x -> f3(x, 1.0), 2.) === 5.0 @test fdiff(y -> f3(1.0, y), 2.) === 2.0 + @test fdiff2(x -> f3(x, 1.0), 2.) === 5.0 + @test fdiff2(y -> f3(1.0, y), 2.) 
=== 2.0 # external module (checks correct type escaping, PR #1446) Enzyme.@import_frule typeof(MockModule.mock_function) MockModule.MockType diff --git a/test/rules.jl b/test/rules.jl index 0ef2e0fe8e..b306c353fb 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -61,11 +61,11 @@ end @test autodiff(Forward, f, Duplicated(2.0, 1.0))[1] ≈ 14.0 @test autodiff(Forward, x->f(x)^2, Duplicated(2.0, 1.0))[1] ≈ 832.0 - res = autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + res = autodiff(Forward, f, BatchDuplicated, BatchDuplicated(2.0, (1.0, 3.0)))[1] @test res[1] ≈ 1004.0 @test res[2] ≈ 1012.0 - res = Enzyme.autodiff(Forward, x->f(x)^2, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + res = Enzyme.autodiff(Forward, x->f(x)^2, BatchDuplicated, BatchDuplicated(2.0, (1.0, 3.0)))[1] @test res[1] ≈ 80032.0 @test res[2] ≈ 80096.0 @@ -129,7 +129,7 @@ end @testset "Shadow" begin @test Enzyme.autodiff(Forward, h, Duplicated(3.0, 1.0)) == (6000.0,) - @test Enzyme.autodiff(Forward, h, Duplicated, Duplicated(3.0, 1.0)) == (9.0, 60.0) + @test Enzyme.autodiff(ForwardWithPrimal, h, Duplicated(3.0, 1.0)) == (60.0, 9.0) @test Enzyme.autodiff(Forward, h2, Duplicated(3.0, 1.0)) == (1080.0,) @test_throws Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0)) end @@ -149,10 +149,10 @@ function EnzymeRules.forward(config, end @testset "Batch complex" begin - res = autodiff(Forward, foo, BatchDuplicated, BatchDuplicated(0.1 + 0im, (0.2 + 0im, 0.3 + 0im))) # errors, see below - @test res[1] ≈ 0.2 + 0.0im - @test res[2][1] ≈ 0.4 + 0.0im - @test res[2][2] ≈ 0.6 + 0.0im + res = autodiff(ForwardWithPrimal, foo, BatchDuplicated(0.1 + 0im, (0.2 + 0im, 0.3 + 0im))) + @test res[2] ≈ 0.2 + 0.0im + @test res[1][1] ≈ 0.4 + 0.0im + @test res[1][2] ≈ 0.6 + 0.0im end end # module ForwardRules diff --git a/test/runtests.jl b/test/runtests.jl index bdda7604bf..18d765938d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ 
-342,8 +342,8 @@ make3() = (1.0, 2.0, 3.0) f1(x) = 1.0 + x f2(x) = x*x @test autodiff(Reverse, f1, Active, Active(1.0))[1][1] ≈ 1.0 - @test autodiff(Forward, f1, DuplicatedNoNeed, Duplicated(1.0, 1.0))[1] ≈ 1.0 - @test autodiff(Forward, f1, Duplicated, Duplicated(1.0, 1.0))[2] ≈ 1.0 + @test autodiff(Forward, f1, Duplicated, Duplicated(1.0, 1.0))[1] ≈ 1.0 + @test autodiff(ForwardWithPrimal, f1, Duplicated, Duplicated(1.0, 1.0))[1] ≈ 1.0 @test autodiff(Reverse, f2, Active, Active(1.0))[1][1] ≈ 2.0 @test autodiff(Forward, f2, Duplicated(1.0, 1.0))[1] ≈ 2.0 tup = autodiff(Forward, f2, BatchDuplicated(1.0, (1.0, 2.0, 3.0)))[1] @@ -1323,8 +1323,8 @@ end (sin(x)::Float64 + x)::Float64 end @test 0.5838531634528576 ≈ Enzyme.autodiff(Reverse, boxfloat, Active, Active(2.0))[1][1] - @test 0.5838531634528576 ≈ Enzyme.autodiff(Forward, boxfloat, DuplicatedNoNeed, Duplicated(2.0, 1.0))[1] - res = Enzyme.autodiff(Forward, boxfloat, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 2.0)))[1] + @test 0.5838531634528576 ≈ Enzyme.autodiff(Forward, boxfloat, Duplicated, Duplicated(2.0, 1.0))[1] + res = Enzyme.autodiff(Forward, boxfloat, BatchDuplicated, BatchDuplicated(2.0, (1.0, 2.0)))[1] @test 0.5838531634528576 ≈ res[1] @test 1.1677063269057153 ≈ res[2] end @@ -1420,9 +1420,9 @@ function rtg_f(V,@nospecialize(cv)) end @testset "RuntimeActivity generic call" begin - res = autodiff(set_runtime_activity(Forward), rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) - @test 3.14 ≈ res[1] - @test 0.0 ≈ res[2] + res = autodiff(set_runtime_activity(ForwardWithPrimal), rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) + @test 3.14 ≈ res[2] + @test 0.0 ≈ res[1] end @inline function myquantile(v::AbstractVector, p::Real; alpha) @@ -1452,9 +1452,9 @@ end @testset "Attributor issues" begin cor = fquantile(2.0) - res = autodiff(Forward, fquantile, Duplicated,Duplicated(2.0, 1.0)) - @test cor ≈ res[1] - @test 0.7 ≈ res[2] + res = autodiff(ForwardWithPrimal, 
fquantile, Duplicated,Duplicated(2.0, 1.0)) + @test cor ≈ res[2] + @test 0.7 ≈ res[1] end @@ -1739,13 +1739,13 @@ end dx = [1.0, 1.0, 1.0] dx2 = [10.0, 20.0, 30.0] - res = Enzyme.autodiff(Forward, fwdlatestfoo, BatchDuplicated, BatchDuplicated(x, (dx, dx2))) + res = Enzyme.autodiff(ForwardWithPrimal, fwdlatestfoo, BatchDuplicated, BatchDuplicated(x, (dx, dx2))) @test 2.0 ≈ res[1][1] + @test 20.0 ≈ res[1][2] @test 2.0 ≈ res[2][1] - @test 20.0 ≈ res[2][2] - res = Enzyme.autodiff(Forward, fwdlatestfoo, BatchDuplicatedNoNeed, BatchDuplicated(x, (dx, dx2))) + res = Enzyme.autodiff(Forward, fwdlatestfoo, BatchDuplicated, BatchDuplicated(x, (dx, dx2))) @test 2.0 ≈ res[1][1] @test 20.0 ≈ res[1][2] @@ -2712,14 +2712,14 @@ end @testset "Batch Forward" begin square(x)=x*x - bres = autodiff(Forward, square, BatchDuplicatedNoNeed, BatchDuplicated(3.0, (1.0, 2.0, 3.0))) + bres = autodiff(Forward, square, BatchDuplicated, BatchDuplicated(3.0, (1.0, 2.0, 3.0))) @test length(bres) == 1 @test length(bres[1]) == 3 @test bres[1][1] ≈ 6.0 @test bres[1][2] ≈ 12.0 @test bres[1][3] ≈ 18.0 - bres = autodiff(Forward, square, BatchDuplicatedNoNeed, BatchDuplicated(3.0 + 7.0im, (1.0+0im, 2.0+0im, 3.0+0im))) + bres = autodiff(Forward, square, BatchDuplicated, BatchDuplicated(3.0 + 7.0im, (1.0+0im, 2.0+0im, 3.0+0im))) @test bres[1][1] ≈ 6.0 + 14.0im @test bres[1][2] ≈ 12.0 + 28.0im @test bres[1][3] ≈ 18.0 + 42.0im @@ -2729,10 +2729,10 @@ end # Shadow offset is not the same as primal so following doesn't work # d_inp = Float32[1.0, 2.0, 3.0] - # autodiff(Forward, squareidx, BatchDuplicatedNoNeed, BatchDuplicated(view(inp, 1:1), (view(d_inp, 1:1), view(d_inp, 2:2), view(d_inp, 3:3)))) + # autodiff(Forward, squareidx, BatchDuplicated, BatchDuplicated(view(inp, 1:1), (view(d_inp, 1:1), view(d_inp, 2:2), view(d_inp, 3:3)))) d_inp = (Float32[1.0], Float32[2.0], Float32[3.0]) - bres = autodiff(Forward, squareidx, BatchDuplicatedNoNeed, BatchDuplicated(inp, d_inp)) + bres = autodiff(Forward, squareidx, 
BatchDuplicated, BatchDuplicated(inp, d_inp)) @test bres[1][1] ≈ 6.0 @test bres[1][2] ≈ 12.0 @test bres[1][3] ≈ 18.0 @@ -3559,11 +3559,11 @@ end fn(0.0) end - res = autodiff(set_runtime_activity(Forward), Const(f2), Duplicated, Duplicated(0.2, 1.0)) - @test res[1] ≈ 0.2 + res = autodiff(set_runtime_activity(ForwardWithPrimal), Const(f2), Duplicated, Duplicated(0.2, 1.0)) + @test res[2] ≈ 0.2 # broken as the return of an apply generic is {primal, primal} # but since the return is abstractfloat doing the - @test res[2] ≈ 1.0 + @test res[1] ≈ 1.0 end @inline function uns_mymean(f, A, ::Type{T}, c) where T From bd5dcd10c703c43d5fbabb1e851b849089244dfd Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 14:37:41 -0500 Subject: [PATCH 289/495] Auto upgrade to autodiff_deferred in nested AD (#1839) * WIP * Upgrade non deferred to deferred * cleanup * Update Project.toml * cleanup nested AD example --------- Co-authored-by: Michel Schanen --- examples/autodiff.jl | 4 ++-- src/Enzyme.jl | 45 ++----------------------------------- src/compiler/interpreter.jl | 32 +++++++++++++++++++++++++- test/runtests.jl | 8 +++++++ 4 files changed, 43 insertions(+), 46 deletions(-) diff --git a/examples/autodiff.jl b/examples/autodiff.jl index 669f3b6809..6bd0b74fb5 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -98,7 +98,7 @@ dby = [0.0] Enzyme.autodiff( Forward, - (x,y) -> Enzyme.autodiff_deferred(Reverse, f, x, y), + (x,y) -> Enzyme.autodiff(Reverse, f, x, y), Duplicated(Duplicated(x, bx), Duplicated(dx, dbx)), Duplicated(Duplicated(y, by), Duplicated(dy, dby)), ) @@ -121,7 +121,7 @@ dbx[2] == 1.0 # \end{aligned} # ``` function grad(x, dx, y, dy) - Enzyme.autodiff_deferred(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) + Enzyme.autodiff(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) nothing end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index bb86a33fc7..583035593d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1084,31 +1084,6 @@ 
grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) end end -""" - gradient_deferred(::ReverseMode, f, x) - -Like [`gradient`](@ref), except it using deferred mode. -""" -@inline function gradient_deferred(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - dx = Ref(make_zero(x)) - autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) - if ReturnPrimal - return (only(dx), res[2]) - else - return only(dx) - end - else - dx = make_zero(x) - autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - if ReturnPrimal - (dx, res[2]) - else - dx - end - end -end - """ gradient!(::ReverseMode, dx, f, x) @@ -1149,22 +1124,6 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) end end - -""" - gradient_deferred!(::ReverseMode, f, x) - -Like [`gradient!`](@ref), except it using deferred mode. 
-""" -@inline function gradient_deferred!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - make_zero!(dx) - autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - return if ReturnPrimal - (dx, res[2]) - else - dx - end -end - """ gradient(::ForwardMode, f, x; shadow=onehot(x)) @@ -1605,7 +1564,7 @@ res """ @inline function hvp!(res::X, f::F, x::X, v::X) where {F, X} grad = make_zero(x) - Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff(Forward, gradient!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) return nothing end @@ -1640,7 +1599,7 @@ grad ``` """ @inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} - Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff(Forward, gradient!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) return nothing end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2ef66a1571..482690e20f 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -212,4 +212,34 @@ let # overload `inlining_policy` end end -end # module Interpreter +import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, + CallMeta, Effects, NoCallInfo, widenconst, mapany + +struct AutodiffCallInfo <: CallInfo + # ... 
+ info::CallInfo +end + +function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, + max_methods::Int = get_max_methods(interp, f, sv)) + + (; fargs, argtypes) = arginfo + + if f === Enzyme.autodiff && length(argtypes) >= 4 + if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} + arginfo2 = ArgInfo( + fargs isa Nothing ? nothing : [:(Enzyme.autodiff_deferred), fargs[2:end]...], + [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...] + ) + return abstract_call_known( + interp, Enzyme.autodiff_deferred, arginfo2, + si, sv, max_methods) + end + end + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, f, arginfo::ArgInfo, + si::StmtInfo, sv::AbsIntState, max_methods::Int) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 18d765938d..b079c0f540 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -486,6 +486,14 @@ end end +@testset "Deferred upgrade" begin + function gradsin(x) + return gradient(Reverse, sin, x) + end + res = Enzyme.gradient(Reverse, gradsin, 3.1) + @test res ≈ -sin(3.1) +end + @testset "Simple Complex tests" begin mul2(z) = 2 * z square(z) = z * z From 786a998f0dc5343703c5420eae40cb790575e218 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 17 Sep 2024 15:30:36 -0500 Subject: [PATCH 290/495] Update sugar apis (#1844) * Update sugar apis * cleanup * cleanup * cleanup * fix * fix * fix stack * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update internal_rules.jl * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --- ext/EnzymeStaticArraysExt.jl | 2 + src/Enzyme.jl | 734 ++++++++++++++++++++--------------- test/ext/logexpfunctions.jl | 4 +- test/internal_rules.jl | 38 +- test/runtests.jl | 326 ++++++++-------- 5 files changed, 609 
insertions(+), 495 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 6dbd390cb7..bcaa3ec6cb 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,6 +3,8 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme +@inline Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) = reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) + @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i Base.@_inline_meta diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 583035593d..66551f2958 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1024,12 +1024,16 @@ end end """ - gradient(::ReverseMode, f, x) + gradient(::ReverseMode, f, args...) Compute the gradient of a real-valued function `f` using reverse mode. -This will allocate and return new array `make_zero(x)` with the gradient result. +For each differentiable argument, this function will allocate and return a new derivative object, returning +a tuple of derivatives for each argument. If an argument is not differentiable, the element of the returned +tuple will be nothing. -Besides arrays, for struct `x` it returns another instance of the same type, +In reverse mode (here), the derivatives will be the same type as the original argument. + +This is a structure gradient. For a struct `x` it returns another instance of the same type, whose fields contain the components of the gradient. In the result, `grad.a` contains `∂f/∂x.a` for any differential `x.a`, while `grad.c == x.c` for other types.
@@ -1042,44 +1046,128 @@ f(x) = x[1]*x[2] grad = gradient(Reverse, f, [2.0, 3.0]) # output +([3.0, 2.0],) +``` -2-element Vector{Float64}: - 3.0 - 2.0 +```jldoctest gradient +grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) + +# output + +((a = 3.0, b = [2.0], c = "str"),) ``` ```jldoctest gradient +mul(x, y) = x[1]*y[1] -grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) +grad = gradient(Reverse, mul, [2.0], [3.0]) # output -([3.0, 2.0], 6.0) +([3.0], [2.0]) ``` ```jldoctest gradient -grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) + +grad = gradient(Reverse, mul, [2.0], Const([3.0])) + +# output +([3.0], nothing) +``` + +If passing a mode that returns the primal (e.g. ReverseWithPrimal), the return type will instead be +a tuple where the first element contains the derivatives, and the second element contains the result of the original computation. + +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) + +# output +(([3.0, 2.0],), 6.0) +``` +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0]) # output +(([3.0], [2.0]), 6.0) +``` + +```jldoctest gradient +grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) -(a = 3.0, b = [2.0], c = "str") +# output +(([3.0], nothing), 6.0) ``` + """ -@inline function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - dx = Ref(make_zero(x)) - res = autodiff(rm, f, Active, MixedDuplicated(x, dx)) - if ReturnPrimal - (only(dx), res[2]) - else - only(dx) +@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{<:Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} + toemit= 
Expr[quote + act_0 = !(x isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof(x), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + end] + rargs = Union{Symbol,Expr}[:x] + acts = Symbol[Symbol("act_0")] + + for i in 1:N + argidx = quote args[$i] end + push!(rargs, argidx) + sym = Symbol("act_$i") + push!(acts, sym) + push!(toemit, quote + $sym = !($argidx isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof($argidx), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + end) + end + + idx = 0 + shadows = Symbol[] + enz_args = Expr[] + resargs = Expr[] + for (arg, act) in zip(rargs, acts) + shad = Symbol("shad_$idx") + push!(shadows, shad) + push!(toemit, quote + $shad = if $arg isa Enzyme.Const + nothing + elseif $act + Ref(make_zero($arg)) + else + make_zero($arg) + end + end) + push!(enz_args, quote + if $arg isa Enzyme.Const + $arg + elseif $act + MixedDuplicated($arg, $shad) + else + Duplicated($arg, $shad) + end + end) + push!(resargs, quote + if $arg isa Enzyme.Const + nothing + elseif $act + $shad[] + else + $shad + end + end) + idx+=1 + end + push!(toemit, quote + res = autodiff(rm, f, Active, $(enz_args...)) + end) + + if ReturnPrimal + return quote + Base.@_inline_meta + $(toemit...) + (($(resargs...),), res[2]) end else - dx = make_zero(x) - res = autodiff(rm, f, Active, Duplicated(x, dx)) - if ReturnPrimal - (dx, res[2]) - else - dx + return quote + Base.@_inline_meta + $(toemit...) 
+ ($(resargs...),) end end end @@ -1100,10 +1188,7 @@ dx = [0.0, 0.0] gradient!(Reverse, dx, f, [2.0, 3.0]) # output - -2-element Vector{Float64}: - 3.0 - 2.0 +([3.0, 2.0],) ``` ```jldoctest gradip @@ -1111,21 +1196,87 @@ dx = [0.0, 0.0] gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) # output -([3.0, 2.0], 6.0) +(([3.0, 2.0],), 6.0) ``` """ @inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (dx, res[2]) + ((dx,), res[2]) else - dx + (dx,) + end +end + +@inline function chunkedonehot(x, ::Val{chunk}) where chunk + sz = length(x) + num = ((sz + chunk - 1) ÷ chunk) + ntuple(Val(num)) do i + Base.@_inline_meta + onehot(x, (i-1)*chunk+1, i == num ? sz : (i*chunk) ) + end +end + +@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk + return ((one(x),),) +end + +@inline tupleconcat(x) = x +@inline tupleconcat(x, y) = (x..., y...) +@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) 
+ +function create_shadows(::Nothing, x) + return (onehot(x),) +end + +function create_shadows(::Val{1}, x) + return (onehot(x),) +end + +function create_shadows(::Val{chunk}, x) where chunk + return (chunkedonehot(x, Val(chunk)),) +end + +struct TupleArray{T, Shape, Length, N} <: AbstractArray{T,N} + data::NTuple{Length, T} +end +TupleArray(data::NTuple{Length, T}, Shape) where {Length, T} = TupleArray{T, Shape, Length, length(Shape)}(data) + +@inline Base.eltype(::TupleArray{T}) where T = T +@inline Base.eltype(::Type{<:TupleArray{T}}) where T = T +@inline Base.size(::TupleArray{<:Any, Shape}) where Shape = Shape +@inline Base.ndims(::TupleArray{<:Any, <:Any, <:Any, N}) where N = N + +function Base.convert(::Type{Array{T, N}}, X::TupleArray{T, Shape, Length, N}) where {T, Shape, Length, N} + vals = Array{T, N}(undef, Shape...) + for i in 1:Length + @inbounds val[i] = X.data[i] + end + return vals +end + +function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} + start = 0 + for i in 1:N + start *= size(a, N - i + 1) + start += (args[N - i + 1] - 1) + end + start += 1 + return a.data[start] +end + +@inline function tupstack(x, inshape, outshape) + st = Base.stack(x) + if length(outshape) == 1 + st + else + reshape(st, (inshape..., outshape...)) end end """ - gradient(::ForwardMode, f, x; shadow=onehot(x)) + gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) Compute the gradient of an array-input function `f` using forward mode. 
The optional keyword argument `shadow` is a vector of one-hot vectors of type `x` @@ -1138,372 +1289,331 @@ Example: ```jldoctest gradfwd f(x) = x[1]*x[2] -grad = gradient(Forward, f, [2.0, 3.0]) +gradient(Forward, f, [2.0, 3.0]) # output -(3.0, 2.0) +([3.0, 2.0],) ``` ```jldoctest gradfwd gradient(ForwardWithPrimal, f, [2.0, 3.0]) # output -((3.0, 2.0), 6.0) +(([3.0, 2.0],), 6.0) ``` -""" -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; shadow=onehot(x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - if length(shadow) == 0 - if ReturnPrimal - ((), f(x.val)) - else - return () - end - end - resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow)) - res = values(resp[1]) - dres = if x isa AbstractFloat - res[1] - else - res - end - if ReturnPrimal - (dres, resp[2]) - else - dres - end -end - -@inline function chunkedonehot(x, ::Val{chunk}) where chunk - sz = length(x) - num = ((sz + chunk - 1) ÷ chunk) - ntuple(Val(num)) do i - Base.@_inline_meta - onehot(x, (i-1)*chunk+1, i == num ? sz : (i*chunk) ) - end -end +```jldoctest gradfwd +gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) -@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk - return ((one(x),),) -end +# output -@inline tupleconcat(x) = x -@inline tupleconcat(x, y) = (x..., y...) -@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) +([3.0, 2.0],) +``` -""" - gradient(::ForwardMode, f, x::Union{Array,NTuple}, ::Val{chunk}; shadow=onehot(x)) +```jldoctest gradfwd +gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) -Compute the gradient of an array-input function `f` using vector forward mode. -Like [`gradient`](@ref), except it uses a chunk size of `chunk` to compute -`chunk` derivatives in a single call. 
+# output +(([3.0, 2.0],), 6.0) +``` -Example: +For functions which return an AbstractArray or scalar, this function will return an AbstractArray +whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made +about the type of the AbstractArray returned by this function (which may or may not be the same +as the input AbstractArray if provided). +For functions which return other types, this function will return an AbstractArray +of shape `size(input)` of values of the output type. ```jldoctest -f(x) = x[1]*x[2] +f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = gradient(Forward, f, [2.0, 3.0], Val(2)) +grad = gradient(Forward, f, [2.0, 3.0, 4.0]) # output - -(3.0, 2.0) +([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` """ -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - if chunk == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - if ReturnPrimal - rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[1]))[1] - dres1 = if chunk == 1 - (rp[1],) - else - values(rp[1]) - end - gres = if x isa AbstractFloat - dres1 +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS} + if length(shadows[1]) == 0 + if ReturnPrimal + ((x,), f(x.val)) else - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() - tmp = ntuple(length(shadow)-1) do i - values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadow[i+1]))[1]) - end - tupleconcat(dres1, tmp...) - end - (gres, rp[2]) - else - tmp = ntuple(length(shadow)) do i - values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[i]))[1]) + return (x,) end - res = tupleconcat(tmp...)
- if x isa AbstractFloat + end + if chunk == Val(0) + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + + gradtup = if chunk == nothing + resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1])) + + res = values(resp[1]) + dres = if x isa AbstractFloat res[1] else res end - end -end - -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - if ReturnPrimal - rp = autodiff(fm, f, Duplicated, Duplicated(x, shadow[1])) - dres1 = rp[1] - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() - - res = ntuple(length(shadow)-1) do i - autodiff(fm2, f, Duplicated, Duplicated(x, shadow[i+1]))[1] + if ReturnPrimal + ((dres,), resp[2]) + else + (dres,) end - gres = if x isa AbstractFloat - dres1 + elseif chunk == Val(1) + if ReturnPrimal + rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1])) + dres1 = rp[1] + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + + res = ntuple(length(shadows[1])-1) do i + autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1] + end + gres = if x isa AbstractFloat + dres1[1] + else + (dres1, res...) + end + ((gres,), rp[2]) else - (dres1, res...) 
+ res = ntuple(length(shadows[1])) do i + autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][i]))[1] + end + (if x isa AbstractFloat + res[1] + else + res + end,) end - (gres, rp[2]) else - res = ntuple(length(shadow)) do i - autodiff(fm, f, Duplicated, Duplicated(x, shadow[i]))[1] - end - if x isa AbstractFloat - res[1] + if ReturnPrimal + rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][1])) + dres1 = values(rp[1]) + gres = if x isa AbstractFloat + dres1[1] + else + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + tmp = ntuple(length(shadows[1])-1) do i + values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i+1]))[1]) + end + tupleconcat(dres1, tmp...) + end + ((gres,), rp[2]) else - res + tmp = ntuple(length(shadows[1])) do i + values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i]))[1]) + end + res = tupleconcat(tmp...) + (if x isa AbstractFloat + res[1] + else + res + end,) end end -end - -""" - jacobian(::ForwardMode, f, x; shadow=onehot(x)) - jacobian(::ForwardMode, f, x, ::Val{chunk}; shadow=onehot(x)) - -Compute the jacobian of an array or scalar-input function `f` using (potentially vector) -forward mode. All relevant arguments of the forward-mode [`gradient`](@ref) function -apply here. - -Example: - -```jldoctest -f(x) = [ x[1] * x[2], x[2] + x[3] ] - -grad = jacobian(Forward, f, [2.0, 3.0, 4.0]) - -# output -2×3 Matrix{Float64}: - 3.0 2.0 0.0 - 0.0 1.0 1.0 -``` - -For functions which return an AbstractArray, this function will return an array -whose shape is `(size(output)..., size(input)...)` - -For functions who return other types, this function will retun an array or tuple -of shape `size(input)` of values of the output type. -""" -@inline function jacobian(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, args...; kwargs...) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - gradtup = gradient(fm, args...; kwargs...) 
cols = if ReturnPrimal - gradtup[1] + gradtup[1][1] else - gradtup + gradtup[1] end - x = args[2] res = if x isa AbstractFloat cols - elseif length(cols) > 0 && cols[1] isa AbstractArray + elseif length(cols) > 0 && cols[1] isa AbstractArray && x isa AbstractArray inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = Base.stack(cols) - - st3 = if length(inshape) <= 1 - st - else - reshape(st, (outshape..., inshape...)) - end - - st3 + tupstack(cols, outshape, inshape) elseif x isa AbstractArray - inshape = size(x) - reshape(collect(cols), inshape) + TupleArray(cols, size(x)) else cols end if ReturnPrimal - (res, gradtup[2]) + ((res,), gradtup[2]) else - res + (res,) end end """ - jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}=Val(1)) + jacobian(::ForwardMode, args...; kwargs...) + +Equivalent to gradient(::ForwardMode, args...; kwargs...) +""" +@inline function jacobian(fm::ForwardMode, args...; kwargs...) + gradient(fm, args...; kwargs...) +end + +""" + jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=nothing) jacobian(::ReverseMode, f, x) -Compute the jacobian of an array-output function `f` using (potentially vector) -reverse mode. The `chunk` argument denotes the chunk size to use and `num_outs` -denotes the number of outputs `f` will return in an array. +Compute the jacobian of a array-output function `f` using (potentially vector) +reverse mode. The `chunk` argument denotes the chunk size to use and `n_outs` +denotes the shape of the array returned by `f`. 
Example: ```jldoctest f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], Val(2)) +jacobian(Reverse, f, [2.0, 3.0, 4.0]) # output +([3.0 2.0 0.0; 0.0 1.0 1.0],) +``` -2×3 transpose(::Matrix{Float64}) with eltype Float64: - 3.0 2.0 0.0 - 0.0 1.0 1.0 +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) + +# output +([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` -For functions which return an AbstractArray, this function will return an array -whose shape is `(size(output)..., size(input)...)` +This function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. +No guarantees are presently made about the type of the AbstractArray returned by this function +(which may or may not be the same as the input AbstractArray if provided). -For functions who return other types, this function will retun an array or tuple -of shape `size(output)` of values of the input type. +In the future, when this function is extended to handle non-array return types, +this function will return an AbstractArray of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity, RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity} - num = ((n_out_val + chunk - 1) ÷ chunk) - - if chunk == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - - XT = Core.Typeof(x) - MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - tt′ = MD ?
Tuple{BatchMixedDuplicated{XT, chunk}} : Tuple{BatchDuplicated{XT, chunk}} - tt = Tuple{XT} - rt = Core.Compiler.return_type(f, tt) - ModifiedBetween = Val((false, false)) - FA = Const{Core.Typeof(f)} - opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) - else - Val(codegen_world_age(Core.Typeof(f), tt)) - end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - - if num * chunk == n_out_val - last_size = chunk - primal2, adjoint2 = primal, adjoint - else - last_size = n_out_val - (num-1)*chunk - tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - end +@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X; n_outs::OutType=nothing, chunk::CT=nothing) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, OutType, CT, Holomorphic} - tmp = ntuple(num) do i - Base.@_inline_meta - dx = ntuple(Val(i == num ? last_size : chunk)) do idx - Base.@_inline_meta - z = make_zero(x) - MD ? Ref(z) : z - end - res = (i == num ? primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) - tape = res[1] - j = 0 - for shadow in res[3] - j += 1 - @inbounds shadow[(i-1)*chunk+j] += Compiler.default_adjoint(eltype(typeof(shadow))) + if n_outs == nothing + res = if f isa Const + f.val(x) + else + f(x) end - (i == num ? adjoint2 : adjoint)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) - return MD ? (ntuple(Val(i == num ? 
last_size : chunk)) do idx - Base.@_inline_meta - dx[idx][] - end) : dx, (i == 1 ? size(res[3][1]) : nothing) - end - rows = tupleconcat(map(first, tmp)...) - outshape = tmp[1][2] - if x isa AbstractArray - inshape = size(x) - - st = Base.stack(rows) - - st2 = if length(outshape) == 1 - st + jac = if res isa AbstractArray + jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x; n_outs=Val(size(res)), chunk) + elseif res isa AbstractFloat + gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) else - reshape(st, (inshape..., outshape...)) + throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) end - st3 = if length(outshape) == 1 && length(inshape) == 1 - transpose(st2) + return if ReturnPrimal + (jac, res) else - transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... ) - PermutedDimsArray(st2, transp) + jac end - - st3 else - reshape(collect(rows), outshape) - end -end - -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RuntimeActivity,RABI<:ABI, ErrIfFuncWritten} - XT = Core.Typeof(x) - MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} - tt = Tuple{XT} - rt = Core.Compiler.return_type(f, tt) - ModifiedBetween = Val((false, false)) - FA = Const{Core.Typeof(f)} - opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) - else - Val(codegen_world_age(Core.Typeof(f), tt)) - end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - tmp = ntuple(n_outs) do i - Base.@_inline_meta - z = make_zero(x) - dx = MD ? Ref(z) : z - res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx)) - tape = res[1] - @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) - adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) - return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) - end - rows = map(first, tmp) - outshape = tmp[1][2] - if x isa AbstractArray - inshape = size(x) - st = Base.stack(rows) - - st2 = if length(outshape) == 1 - st + @assert !Holomorphic + n_out_val = if length(Compiler.element(n_outs)) == 0 + 0 else - reshape(st, (inshape..., outshape...)) + prod(Compiler.element(n_outs)) end - - st3 = if length(outshape) == 1 && length(inshape) == 1 - transpose(st2) + + if chunk == Val(0) + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + tt = Tuple{XT} + rt = if f isa Const + Core.Compiler.return_type(f.val, tt) + else + Core.Compiler.return_type(f, tt) + end + + ModifiedBetween = Val((false, false)) + FRT = Core.Typeof(f) + FA = Const{FRT} + + opt_mi = if RABI <: NonGenABI + Compiler.fspec(FRT, tt′) else - transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... 
) - PermutedDimsArray(st2, transp) + Val(codegen_world_age(FRT, tt)) end - st3 - else - reshape(collect(rows), outshape) - end -end + if chunk == Val(1) || chunk == nothing + tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + tmp = ntuple(Val(n_out_val)) do i + Base.@_inline_meta + z = make_zero(x) + dx = MD ? Ref(z) : z + res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx)) + tape = res[1] + @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) + adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) + return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) + end + rows = map(first, tmp) + outshape = tmp[1][2] + rows, outshape + else + chunksize = Compiler.element(chunk) + tt′ = MD ? 
Tuple{BatchMixedDuplicated{XT, chunksize}} : Tuple{BatchDuplicated{XT, chunksize}} + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#chunk, ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + + num = ((n_out_val + chunksize - 1) ÷ chunksize) + + if num * chunksize == n_out_val + last_size = chunksize + primal2, adjoint2 = primal, adjoint + else + last_size = n_out_val - (num-1)*chunksize + tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} + primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + end -@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, Holomorphic} - res = f(x) - jac = if res isa AbstractArray - jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) - elseif res isa AbstractFloat - gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) - else - throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) - end + tmp = ntuple(num) do i + Base.@_inline_meta + dx = ntuple(Val(i == num ? last_size : chunksize)) do idx + Base.@_inline_meta + z = make_zero(x) + MD ? Ref(z) : z + end + res = (i == num ? primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) + tape = res[1] + j = 0 + for shadow in res[3] + j += 1 + @inbounds shadow[(i-1)*chunksize+j] += Compiler.default_adjoint(eltype(typeof(shadow))) + end + (i == num ? adjoint2 : adjoint)(Const(f), MD ? 
BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) + return MD ? (ntuple(Val(i == num ? last_size : chunksize)) do idx + Base.@_inline_meta + dx[idx][] + end) : dx, (i == 1 ? size(res[3][1]) : nothing) + end + rows = tupleconcat(map(first, tmp)...) + outshape = tmp[1][2] + rows, outshape + end + res = if x isa AbstractArray + inshape = size(x) + st2 = tupstack(rows, inshape, outshape) - if ReturnPrimal - (res, jac) - else - jac + st3 = if length(outshape) == 1 && length(inshape) == 1 + transpose(st2) + else + transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... ) + PermutedDimsArray(st2, transp) + end + + st3 + else + reshape(collect(rows), outshape) + end + if ReturnPrimal + # TODO optimize away redundant fwd pass + (res, if f isa Enzyme.Const + f.val(x) + else + f(x) + end) + else + (res,) + end end end diff --git a/test/ext/logexpfunctions.jl b/test/ext/logexpfunctions.jl index 69ee7f2e73..51dbe2ec76 100644 --- a/test/ext/logexpfunctions.jl +++ b/test/ext/logexpfunctions.jl @@ -9,6 +9,6 @@ xlogydiff(x) = xlogy(x[1], 23.0) grad_forward = Enzyme.gradient(Enzyme.Forward, xlogydiff, x) grad_reverse = Enzyme.gradient(Enzyme.Reverse, xlogydiff, x) - @test grad_forward[1] ≈ log(23.0) - @test grad_reverse[1] ≈ log(23.0) + @test grad_forward[1] ≈ [log(23.0)] + @test grad_reverse[1] ≈ [log(23.0)] end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index b9a705941c..32a206c62e 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -197,14 +197,14 @@ end dL = zero(x) dL[2, 1] = 1.0 - @test Enzyme.gradient(Reverse, chol_lower0, x) ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Reverse, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test reshape(collect(Enzyme.gradient(Forward, chol_lower0, x)), 4, 4) ≈ [0.05270807565639164 0.0 0.0 0.0; 
1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Forward, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] @test FiniteDifferences.grad(central_fdm(5, 1), chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test reshape(collect(Enzyme.gradient(Forward, chol_upper0, x)), 4, 4) ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test Enzyme.gradient(Reverse, chol_upper0, x) ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Forward, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Reverse, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] @test FiniteDifferences.grad(central_fdm(5, 1), chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] end @@ -225,14 +225,14 @@ end x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] for i in 1:size(x, 1) for j in 1:size(x, 2) - reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x) - forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x)[1] + forward_grad = Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_lower(x, i, j), x)[1] @test reverse_grad ≈ finite_diff @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x) - forward_grad = 
reshape(collect(Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x)[1] + forward_grad = Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_upper(x, i, j), x)[1] @test reverse_grad ≈ finite_diff @test forward_grad ≈ finite_diff @@ -257,26 +257,26 @@ end x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] for i in 1:15 B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3] - reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B) - # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)), size(B)) + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B)[1] + # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1] @test reverse_grad ≈ finite_diff # @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B) - # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B)), size(B)) + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B)[1] + # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1] @test reverse_grad ≈ finite_diff # @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x) - #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x)[1] + #forward_grad = Enzyme.gradient(Forward, x ->
tcholsolv_lower(x, B, i), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1] @test reverse_grad ≈ finite_diff #@test forward_grad ≈ finite_diff # - reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x) - #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x)[1] + #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1] @test reverse_grad ≈ finite_diff #@test forward_grad ≈ finite_diff @@ -554,7 +554,7 @@ end b = [1., 2.] dA = zero(A) Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) - # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A) + # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A)[1] dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A->h(A, b), A)[1] @test isapprox(dA, dA_fd) @@ -571,9 +571,9 @@ end @testset "Cholesky upper triangular v1" begin x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0] - @test collect(Enzyme.gradient(Forward, chol_upper, x)) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + @test Enzyme.gradient(Forward, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - @test Enzyme.gradient(Reverse, chol_upper, x) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + @test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] end @testset "Linear solve for triangular matrices" begin diff --git a/test/runtests.jl b/test/runtests.jl index b079c0f540..65ad4e3fd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -488,9 +488,9 @@ end @testset "Deferred upgrade" begin function gradsin(x) - return gradient(Reverse, sin, x) + return gradient(Reverse, sin, x)[1] end - res = Enzyme.gradient(Reverse, gradsin, 3.1) + res = Enzyme.gradient(Reverse, gradsin, 3.1)[1] @test res ≈ -sin(3.1) end @@ -2794,43 +2794,43 @@ end @testset "Gradient & NamedTuples" begin xy = (x = [1.0, 2.0], y = [3.0, 4.0]) - grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy)[1] @test grad == (x = [3.0, 4.0], y = [1.0, 2.0]) xp = (x = [1.0, 2.0], p = 3) # 3::Int is non-diff - grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp)[1] @test grad.x == [3.0, 12.0] xp2 = (x = [1.0, 2.0], p = 3.0) # mixed activity - grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp2) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp2)[1] @test grad.x == [3.0, 12.0] @test grad.p ≈ 5.545177444479562 xy = (x = [1.0, 2.0], y = [3, 4]) # y is non-diff - grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy)[1] @test grad.x == [3.0, 4.0] @test grad.y === xy.y # make_zero did not copy this - grad = Enzyme.gradient(Reverse, z -> (z.x * z.y), (x=5.0, y=6.0)) + grad = Enzyme.gradient(Reverse, z -> (z.x * z.y), (x=5.0, y=6.0))[1] @test grad == (x = 6.0, y = 5.0) - grad = Enzyme.gradient(Reverse, abs2, 7.0) + grad = Enzyme.gradient(Reverse, abs2, 7.0)[1] @test grad == 14.0 end @testset "Gradient & SparseArrays / StaticArrays" begin x = sparse([5.0, 0.0, 6.0]) - dx = Enzyme.gradient(Reverse, sum, x) + dx = Enzyme.gradient(Reverse, sum, x)[1] @test dx isa SparseVector @test dx ≈ [1, 0, 1] x = sparse([5.0 0.0 6.0]) - dx 
= Enzyme.gradient(Reverse, sum, x) + dx = Enzyme.gradient(Reverse, sum, x)[1] @test dx isa SparseMatrixCSC @test dx ≈ [1 0 1] x = @SArray [5.0 0.0 6.0] - dx = Enzyme.gradient(Reverse, prod, x) + dx = Enzyme.gradient(Reverse, prod, x)[1] @test dx isa SArray @test dx ≈ [0 30 0] @@ -2851,7 +2851,7 @@ end @test y[2] == [0.0, 0.0, 1.0] x = @SArray [5.0 0.0 6.0] - dx = Enzyme.gradient(Forward, prod, x) + dx = Enzyme.gradient(Forward, prod, x)[1] @test dx[1] ≈ 0 @test dx[2] ≈ 30 @test dx[3] ≈ 0 @@ -2906,264 +2906,266 @@ mkarray(sz, args...) = reshape(vcat(args...), sz) scalar = 3.0 # ∂ scalar / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 # ∂ vector / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ 
[2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] # ∂ tuple / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≈ [2.0, 6.0] + @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≈ [2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) # ∂ matrix / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] # ∂ struct / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == 
OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) vector = [2.7, 3.1] # ∂ scalar / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector) ≃ (vector[2],vector[1]) - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2],vector[1]] + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] # ∂ vector / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≃ - ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + 
x[2]], vector) ≈ + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ [vector[2] vector[1]; -sin(vector[1]) 1.0] # ∂ tuple / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ - ((vector[2], -sin(vector[1])), (vector[1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ [(vector[2], -sin(vector[1])), (vector[1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] mkarray2 = x -> mkarray((2,2), x[1]*x[2], exp(x[2]), cos(x[1])+x[2], x[1]) # ∂ matrix / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector) ≃ - ([vector[2] -sin(vector[1]); 0.0 1.0], [vector[1] 1.0; exp(vector[2]) 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector) - @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector) ≈ + @test 
Enzyme.gradient(Enzyme.Forward, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector)[1] + @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector)[1] ≈ mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector)[1] ≈ mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) # ∂ struct / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ - (OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) tuplev = (2.7, 3.1) # ∂ scalar / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], 
tuplev) ≃ (tuplev[2],tuplev[1]) - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) # ∂ vector / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] # ∂ tuple / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ ((vector[2], -sin(vector[1])), (vector[1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], 
cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] # ∂ matrix / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev)[1] ≃ ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev) - @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev) ≈ + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev)[1] ≈ [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev) ≈ + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev)[1] ≈ [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] # ∂ struct / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> 
(x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ [OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) matrix = [2.7 3.1; 4.7 5.6] # ∂ scalar / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≃ - (matrix[1,2], matrix[2,2], matrix[1,1], matrix[2,1]) - @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] # ∂ vector / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≃ - ([matrix[1,2], 0.0], [0.0, matrix[2,2]], [matrix[1,1], 
0.0], [0.0, matrix[2,1]]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) + @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] # again we can't use array construction syntax because of 1.6 - @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ ((matrix[1,2], 0.0), (0.0, matrix[2,2]), (matrix[1,1], 0.0), (0.0, matrix[2,1])) + @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) - @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ + @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) # ∂ matrix / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, 
mkarray3, matrix) ≃ - ([matrix[1,2] 0.0; exp(matrix[1,1]) 0.0], [0.0 matrix[2,2]; 0.0 1.0], [matrix[1,1] 0.0; 0.0 cos(matrix[1,2])], [0.0 matrix[2,1]; 1.0 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix) + @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix)[1] # array construction syntax broken on 1.6 - @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix) ≈ + @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix)[1] ≈ mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix)[1] ≈ mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ - (OutStruct(matrix[1,2], 0.0, exp(matrix[1,1])), OutStruct(0.0, matrix[2,2], 0.0), OutStruct(matrix[1,1], 0.0, 0.0), OutStruct(0.0, matrix[2,1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) - @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ + @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] + @test 
Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] istruct = InpStruct(2.7, 3.1, 4.7) # ∂ scalar / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) - @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) - @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) # ∂ vector / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], 
istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] # ∂ tuple / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) # ∂ matrix / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct) ≃ + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0) 
InpStruct(1.0, 0.0, -sin(istruct.i3)); InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] # ∂ struct / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] end @testset "Simple Jacobian" begin - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0]) ≈ [4.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0)[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0])[1] ≈ [4.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(1)) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(1)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(1)) ≈ [4.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(1))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], 
chunk=Val(1))[1] ≈ [4.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(2)) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(2)) ≈ [4.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(2))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(2))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(2))[1] ≈ [4.0, 6.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(1)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(2)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(2))[1] ≈ [1.0, 2.0] x = float.(reshape(1:6, 2, 3)) fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x) + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x)[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(1)) + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(1))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(2)) + jac = 
Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(2))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @@ -3171,14 +3173,14 @@ end @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, Val(4), Val(1)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(1))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, Val(4), Val(2)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(2))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @@ -3189,14 +3191,14 @@ end x2 = InpStruct(1.0, 2.0, 3.0) - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(1)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(1))[1] @test jac[1] == InpStruct(2.0, 4.0, 6.0) @test jac[2] == InpStruct(20.0, 40.0, 60.0) @test jac[3] == InpStruct(200.0, 400.0, 600.0) @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(2)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(2))[1] @test jac[1] == InpStruct(2.0, 4.0, 6.0) @test jac[2] == InpStruct(20.0, 40.0, 60.0) @@ -3205,7 +3207,7 @@ end filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x) + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x)[1] @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) @@ -3216,7 +3218,7 @@ end @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) @test jac[2, 3] == OutStruct(12.0, 
120.0, 1200.0) - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(1)) + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(1))[1] @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) @@ -3227,7 +3229,7 @@ end @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(2)) + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(2))[1] @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) @@ -3245,27 +3247,27 @@ end [v[2], v[1]*v[1], v[1]*v[1]*v[1]] end - jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(1)) + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(1))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; 12.0 0.0] - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(1)) + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(1))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; 12.0 0.0] - @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0]) + @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])[1] - jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(2)) + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(2))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; 12.0 0.0] - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(2)) + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(2))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; @@ -3286,13 +3288,13 @@ end utmp .= A*x[2:end] .+ x[1] end - J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, Val(5)) - J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, Val(5)) - J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, Val(5)) + J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, 
n_outs=Val((5,)))[1] + J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, n_outs=Val((5,)))[1] + J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, n_outs=Val((5,)))[1] - J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x) - J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x) - J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x) + J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x)[1] + J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x)[1] + J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x)[1] x = ones(6) A = Matrix{Float64}(LinearAlgebra.I, 5, 5) @@ -3351,7 +3353,7 @@ end dry = zeros(2) function foo(y, dy, x, dx) - autodiff_deferred(Reverse, speelpenning, Const, Duplicated(y, dy), Duplicated(x, dx)) + autodiff(Reverse, speelpenning, Const, Duplicated(y, dy), Duplicated(x, dx)) return nothing end @@ -3776,8 +3778,8 @@ end @testset "Constant Complex return" begin vec = [0.5] - @test Enzyme.gradient(Enzyme.Reverse, fexpandempty, vec)[1] ≈ 1.0 - @test Enzyme.gradient(Enzyme.Forward, fexpandempty, vec)[1] ≈ 1.0 + @test Enzyme.gradient(Enzyme.Reverse, fexpandempty, vec)[1] ≈ [1.0] + @test Enzyme.gradient(Enzyme.Forward, fexpandempty, vec)[1] ≈ [1.0] end const CUmemoryPool2 = Ptr{Float64} @@ -3924,10 +3926,10 @@ const objective3 = params -> mixture_loglikelihood3(params, data) -13.935687326484112, -38.00044665702692, 12.87712891527131] - @test expected ≈ Enzyme.gradient(Reverse, objective1, params0) + @test expected ≈ Enzyme.gradient(Reverse, objective1, params0)[1] # objective2 fails from runtime activity requirements - # @test expected ≈ Enzyme.gradient(Reverse, objective2, params0) - @test expected ≈ Enzyme.gradient(Reverse, objective3, params0) + # @test expected ≈ Enzyme.gradient(Reverse, objective2, params0)[1] + @test expected ≈ Enzyme.gradient(Reverse, objective3, params0)[1] end struct HarmonicAngle From 
bbaa1f8d8c83daf4b28018d6387148be9121bdb1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 17 Sep 2024 21:06:17 -0500 Subject: [PATCH 291/495] Use namedtuple for grad/jacobian (#1850) * Use namedtuple for grad/jacobian * Update index.md * Update Enzyme.jl * Update Enzyme.jl * Update Enzyme.jl * Update Enzyme.jl * Update index.md --- docs/src/faq.md | 2 +- docs/src/index.md | 88 ++++++++++++++++++++++++------------------- src/Enzyme.jl | 28 +++++++------- test/ext/bfloat16s.jl | 4 +- 4 files changed, 67 insertions(+), 55 deletions(-) diff --git a/docs/src/faq.md b/docs/src/faq.md index 88c0cce3b9..6b3bbce6b4 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -193,7 +193,7 @@ That is why Enzyme provides a helper function `Enzyme.make_zero` that does this ```jldoctest sparse Enzyme.make_zero(a) -Enzyme.gradient(Reverse, sum, a) # This calls make_zero(a) +Enzyme.gradient(Reverse, sum, a)[1] # This calls make_zero(a) # output diff --git a/docs/src/index.md b/docs/src/index.md index 1f7f092a99..2643b87a1e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -76,24 +76,32 @@ Both the inplace and "normal" variant return the gradient. The difference is tha ## Forward mode -The return value of forward mode with a `Duplicated` return is a tuple containing as the first value -the primal return value and as the second value the derivative. +The return value when using `ForwardWithPrimal` is a tuple containing as the first value +the derivative return value and as the second value the original value. + +The return value when using `Forward` is a single-element tuple containing the derivative. In forward mode `Duplicated(x, 0.0)` is equivalent to `Const(x)`, except that we can perform more optimizations for `Const`. 
```jldoctest rosenbrock -julia> autodiff(Forward, rosenbrock, Duplicated, Const(1.0), Duplicated(3.0, 1.0)) +julia> autodiff(ForwardWithPrimal, rosenbrock, Const(1.0), Duplicated(3.0, 1.0)) (400.0, 400.0) -julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Const(3.0)) -(400.0, -800.0) +julia> autodiff(Forward, rosenbrock, Const(1.0), Duplicated(3.0, 1.0)) +(400.0,) + +julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Const(3.0)) +(-800.0, 400.0) + +julia> autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Const(3.0)) +(-800.0,) ``` Of note, when we seed both arguments at once the tangent return is the sum of both. ```jldoctest rosenbrock -julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) +julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) (400.0, -400.0) ``` @@ -121,7 +129,7 @@ Note the seeding through `dx`. We can also use vector mode to calculate both derivatives at once. ```jldoctest rosenbrock -julia> autodiff(Forward, rosenbrock, BatchDuplicated, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0))) +julia> autodiff(ForwardWithPrimal, rosenbrock, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0))) (400.0, (var"1" = -800.0, var"2" = 400.0)) julia> x = [1.0, 3.0] @@ -131,7 +139,7 @@ julia> x = [1.0, 3.0] julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]; -julia> autodiff(Forward, rosenbrock_inp, BatchDuplicated, BatchDuplicated(x, (dx_1, dx_2))) +julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2))) (400.0, (var"1" = -800.0, var"2" = 400.0)) ``` @@ -145,18 +153,20 @@ Like [`autodiff`](@ref), the mode (forward or reverse) is determined by the firs The functions [`gradient`](@ref) and [`gradient!`](@ref) compute the gradient of function with vector input and scalar return. +Gradient functions take a mode as the first argument. 
If the mode is `Reverse` or `Forward`, the return type is a tuple of gradients of each argument. +If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result. + ```jldoctest rosenbrock julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0]) -2-element Vector{Float64}: - -400.0 - 200.0 +([-400.0, 200.0],) + +julia> gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0]) +(derivs=[-400.0, 200.0], val=100.0) julia> # inplace variant dx = [0.0, 0.0]; gradient!(Reverse, dx, rosenbrock_inp, [1.0, 2.0]) -2-element Vector{Float64}: - -400.0 - 200.0 +([-400.0, 200.0],) julia> dx 2-element Vector{Float64}: @@ -164,14 +174,16 @@ julia> dx 200.0 julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0]) -(-400.0, 200.0) +([-400.0, 200.0],) + +julia> gradient(ForwardWithPrimal, rosenbrock_inp, [1.0, 2.0]) +(derivs = [-400.0, 200.0], val = 100.0) julia> # in forward mode, we can also optionally pass a chunk size # to specify the number of derivatives computed simulateneously # using vector forward mode - chunk_size = Val(2) - gradient(Forward, rosenbrock_inp, [1.0, 2.0], chunk_size) -(-400.0, 200.0) + gradient(Forward, rosenbrock_inp, [1.0, 2.0]; chunk=Val(1)) +([-400.0, 200.0],) ``` ## Jacobian Convenience functions @@ -179,31 +191,31 @@ julia> # in forward mode, we can also optionally pass a chunk size The function [`jacobian`](@ref) computes the Jacobian of a function vector input and vector return. Like [`autodiff`](@ref) and [`gradient`](@ref), the mode (forward or reverse) is determined by the first argument. +Again like [`gradient`](@ref), if the mode is `Reverse` or `Forward`, the return type is a tuple of jacobians of each argument. +If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result. 
+ +Both forward and reverse modes take an optional chunk size to compute several derivatives simultaneously using vector mode, and reverse mode optionally takes `n_outs` which describes the shape of the output value. + ```jldoctest rosenbrock julia> foo(x) = [rosenbrock_inp(x), prod(x)]; -julia> output_size = Val(2) # here we have to provide the output size of `foo` since it cannot be statically inferred - jacobian(Reverse, foo, [1.0, 2.0], output_size) -2×2 transpose(::Matrix{Float64}) with eltype Float64: - -400.0 200.0 - 2.0 1.0 +julia> jacobian(Reverse, foo, [1.0, 2.0]) +([-400.0 200.0; 2.0 1.0],) -julia> chunk_size = Val(2) # By specifying the optional chunk size argument, we can use vector inverse mode to propogate derivatives of multiple outputs at once. - jacobian(Reverse, foo, [1.0, 2.0], output_size, chunk_size) -2×2 transpose(::Matrix{Float64}) with eltype Float64: - -400.0 200.0 - 2.0 1.0 +julia> jacobian(ReverseWithPrimal, foo, [1.0, 2.0]) +(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0]) + +julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2)) +([-400.0 200.0; 2.0 1.0],) + +julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2), n_outs=Val((2,))) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(Forward, foo, [1.0, 2.0]) -2×2 Matrix{Float64}: - -400.0 200.0 - 2.0 1.0 - -julia> # Again, the optinal chunk size argument allows us to use vector forward mode - jacobian(Forward, foo, [1.0, 2.0], chunk_size) -2×2 Matrix{Float64}: - -400.0 200.0 - 2.0 1.0 +([-400.0 200.0; 2.0 1.0],) + +julia> jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2)) +([-400.0 200.0; 2.0 1.0],) ``` ## Hessian Vector Product Convenience functions @@ -257,4 +269,4 @@ julia> grad 2-element Vector{Float64}: 2.880510859951098 1.920340573300732 -``` \ No newline at end of file +``` diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 66551f2958..c4994fb363 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1082,21 +1082,21 @@ a tuple where the first element contains the derivatives, and the 
second element grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` ```jldoctest gradient grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0]) # output -(([3.0], [2.0]), 6.0) +(derivs = ([3.0], [2.0]), val = 6.0) ``` ```jldoctest gradient grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) # output -(([3.0], nothing), 6.0) +(derivs = ([3.0], nothing), val = 6.0) ``` """ @@ -1161,7 +1161,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) return quote Base.@_inline_meta $(toemit...) - (($(resargs...),), res[2]) + (; derivs=($(resargs...),), val=res[2]) end else return quote @@ -1196,14 +1196,14 @@ dx = [0.0, 0.0] gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` """ @inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - ((dx,), res[2]) + (; derivs=(dx,), val=res[2]) else (dx,) end @@ -1300,7 +1300,7 @@ gradient(Forward, f, [2.0, 3.0]) gradient(ForwardWithPrimal, f, [2.0, 3.0]) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` ```jldoctest gradfwd @@ -1315,7 +1315,7 @@ gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` For functions which return an AbstractArray or scalar, this function will return an AbstracttArray @@ -1336,10 +1336,10 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) """ @inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS} if 
length(shadows[1]) == 0 - if ReturnPrimal - ((x,), f(x.val)) + return if ReturnPrimal + (; derivs=(x,), val=f(x.val)) else - return (x,) + (x,) end end if chunk == Val(0) @@ -1430,7 +1430,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) cols end if ReturnPrimal - ((res,), gradtup[2]) + (; derivs=(res,), val=gradtup[2]) else (res,) end @@ -1498,7 +1498,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end return if ReturnPrimal - (jac, res) + (; derivs=jac, val=res) else jac end @@ -1606,7 +1606,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end if ReturnPrimal # TODO optimize away redundant fwd pass - (res, if f isa Enzyme.Const + (; derivs=res, val=if f isa Enzyme.Const f.val(x) else f(x) diff --git a/test/ext/bfloat16s.jl b/test/ext/bfloat16s.jl index 0a47f48f03..daaf6ef74c 100644 --- a/test/ext/bfloat16s.jl +++ b/test/ext/bfloat16s.jl @@ -2,6 +2,6 @@ using Enzyme using Test using BFloat16s -@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) +@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10) -@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) +@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10) From 6a19be2cfb982b1d12cacc7c5aa182ed7321801f Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 00:26:29 -0500 Subject: [PATCH 292/495] Simplify deferred functions (#1849) * Simplify deferred functions * fix * Update runtests.jl * fix * fix * fix * fix * fix --- docs/Project.toml | 3 ++- docs/src/faq.md | 2 +- docs/src/index.md | 29 ++++++++++++++++------------- examples/custom_rule.jl | 19 +++++++++++-------- lib/EnzymeCore/src/rules.jl | 12 ++++++++++++ src/Enzyme.jl | 33 --------------------------------- test/abi.jl | 30 +++++++++++++++--------------- test/amdgpu.jl | 6 +++--- test/cuda.jl | 10 +++++----- test/metal.jl | 4 ++-- 
test/runtests.jl | 12 ++++++------ 11 files changed, 73 insertions(+), 87 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 56dd852972..14301cf64d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,8 +1,9 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Literate = "2" Documenter = "1" +Literate = "2" diff --git a/docs/src/faq.md b/docs/src/faq.md index 6b3bbce6b4..c5a80a976d 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -627,7 +627,7 @@ Presently Enzyme only considers floats as base types. As a result, Enzyme does n ```jldoctest types f_int(x) = x * x -Enzyme.autodiff(Forward, f_int, DuplicatedNoNeed, Duplicated(3, 1)) +Enzyme.autodiff(Forward, f_int, Duplicated, Duplicated(3, 1)) # output diff --git a/docs/src/index.md b/docs/src/index.md index 2643b87a1e..3c1c31b4af 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -102,7 +102,10 @@ Of note, when we seed both arguments at once the tangent return is the sum of bo ```jldoctest rosenbrock julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) -(400.0, -400.0) +(-400.0, 400.0) + +julia> autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) +(-400.0,) ``` We can also use forward mode with our inplace method. @@ -118,8 +121,8 @@ julia> dx = [1.0, 1.0] 1.0 1.0 -julia> autodiff(Forward, rosenbrock_inp, Duplicated, Duplicated(x, dx)) -(400.0, -400.0) +julia> autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated, Duplicated(x, dx)) +(-400.0, 400.0) ``` Note the seeding through `dx`. @@ -130,7 +133,7 @@ We can also use vector mode to calculate both derivatives at once. 
```jldoctest rosenbrock julia> autodiff(ForwardWithPrimal, rosenbrock, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0))) -(400.0, (var"1" = -800.0, var"2" = 400.0)) +((var"1" = -800.0, var"2" = 400.0), 400.0) julia> x = [1.0, 3.0] 2-element Vector{Float64}: @@ -140,7 +143,7 @@ julia> x = [1.0, 3.0] julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]; julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2))) -(400.0, (var"1" = -800.0, var"2" = 400.0)) +((var"1" = -800.0, var"2" = 400.0), 400.0) ``` ## Gradient Convenience functions @@ -161,7 +164,7 @@ julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0]) ([-400.0, 200.0],) julia> gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0]) -(derivs=[-400.0, 200.0], val=100.0) +(derivs = ([-400.0, 200.0],), val = 100.0) julia> # inplace variant dx = [0.0, 0.0]; @@ -177,7 +180,7 @@ julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0]) ([-400.0, 200.0],) julia> gradient(ForwardWithPrimal, rosenbrock_inp, [1.0, 2.0]) -(derivs = [-400.0, 200.0], val = 100.0) +(derivs = ([-400.0, 200.0],), val = 100.0) julia> # in forward mode, we can also optionally pass a chunk size # to specify the number of derivatives computed simulateneously @@ -200,22 +203,22 @@ Both forward and reverse modes take an optional chunk size to compute several de julia> foo(x) = [rosenbrock_inp(x), prod(x)]; julia> jacobian(Reverse, foo, [1.0, 2.0]) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(ReverseWithPrimal, foo, [1.0, 2.0]) -(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0]) +(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0]) julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2)) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2), n_outs=Val((2,))) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(Forward, foo, [1.0, 2.0]) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> 
jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2)) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) ``` ## Hessian Vector Product Convenience functions diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 2b3f226fb0..f778939032 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -65,7 +65,7 @@ function forward(config::FwdConfig, func::Const{typeof(f)}, ::Type{<:Duplicated} end # In the signature of our rule, we have made use of `Enzyme`'s activity annotations. Let's break down each one: -# - the [`FwdConfig`](@ref) configuration passes certain compile-time information about differentiation procedure (the width, and if we're using runtime activity), +# - the [`EnzymeRules.FwdConfig`](@ref) configuration passes certain compile-time information about differentiation procedure (the width, and if we're using runtime activity), # - the [`Const`](@ref) annotation on `f` indicates that we accept a function `f` that does not have a derivative component, # which makes sense since `f` is not a closure with data that could be differentiated. # - the [`Duplicated`](@ref) annotation given in the second argument annotates the return value of `f`. This means that @@ -123,8 +123,9 @@ dy = [0.0, 0.0] # If a custom rule is specified for the correct function/argument types, but not the correct activity annotation, # a runtime error will be thrown alerting the user to the missing activity rule rather than silently ignoring the rule." -# Finally, it may be that either `x`, `y`, or the return value are marked as [`Const`](@ref). We can in fact handle this case, -# along with the previous two cases, all together in a single rule: +# Finally, it may be that either `x`, `y`, or the return value are marked as [`Const`](@ref), in which case we can simply return the original result. However, Enzyme also may determine the return is not differentiable and also not needed for other computations, in which case we should simply return nothing. 
+# +# We can in fact handle this case, along with the previous two cases, all together in a single rule by leveraging utility functions [`EnzymeRules.needs_primal`](@ref) and [`EnzymeRules.needs_shadow`](@ref), which return true if the original return or the derivative is needed to be returned, respectively: Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules @@ -138,12 +139,14 @@ function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, Duplica make_zero!(y.dval) end dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val)) - if RT <: Const + if needs_primal(config) && needs_shadow(config) + return Duplicated(sum(y.val), dret) + elseif needs_primal(config) return sum(y.val) - elseif RT <: DuplicatedNoNeed + elseif needs_shadow(config) return dret else - return Duplicated(sum(y.val), dret) + return nothing end end @@ -189,7 +192,7 @@ function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::T end # Let's unpack our signature for `augmented_primal` : -# * We accepted a [`EnzymeRules.Config`](@ref) object with a specified width of 1, which means that our rule does not support batched reverse mode. +# * We accepted a [`EnzymeRules.RevConfig`](@ref) object with a specified width of 1, which means that our rule does not support batched reverse mode. # * We annotated `f` with [`Const`](@ref) as usual. # * We dispatched on an [`Active`](@ref) annotation for the return value. This is a special annotation for scalar values, such as our return value, # that indicates that that we care about the value's derivative but we need not explicitly allocate a mutable shadow since it is a scalar value. @@ -197,7 +200,7 @@ end # Now, let's unpack the body of our `augmented_primal` rule: # * We checked if the `config` requires the primal. If not, we need not compute the return value, but we make sure to mutate `y` in all cases. 
-# * We checked if `x` could possibly be overwritten using the `Overwritten` attribute of [`EnzymeRules.Config`](@ref). +# * We checked if `x` could possibly be overwritten using the `Overwritten` attribute of [`EnzymeRules.RevConfig`](@ref). # If so, we save the elements of `x` on the `tape` of the returned [`EnzymeRules.AugmentedReturn`](@ref) object. # * We return a shadow of `nothing` since the return value is [`Active`](@ref) and hence does not need a shadow. diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 8d01d321da..d4469e9793 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -35,7 +35,19 @@ Getters for the type parameters are provided by `needs_primal`, `needs_shadow`, struct FwdConfig{NeedsPrimal, NeedsShadow, Width, RuntimeActivity} end const FwdConfigWidth{Width} = FwdConfig{<:Any,<:Any,Width} +""" + needs_primal(::FwdConfig) + needs_primal(::RevConfig) + +Whether a custom rule should return the original result of the function. +""" @inline needs_primal(::FwdConfig{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +""" + needs_shadow(::FwdConfig) + needs_shadow(::RevConfig) + +Whether a custom rule should return the shadow (derivative) of the function result. +""" @inline needs_shadow(::FwdConfig{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow @inline width(::FwdConfig{<:Any, <:Any, Width}) where Width = Width diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c4994fb363..2b6f1f3627 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -531,39 +531,6 @@ code, as well as high-order differentiation. thunk(f, args...) end -""" - autodiff_deferred(mode::Mode, f, ::Type{A}, args) - -Like [`autodiff_deferred`](@ref) but will try to extend f to an annotation, if needed. -""" -@inline function autodiff_deferred(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} - autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), args...) 
-end -@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} - autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), RT, args...) -end - -""" - autodiff_deferred(mode, f, args...) - -Like [`autodiff_deferred`](@ref) but will try to guess the activity of the return value. -""" - -@inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs} - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode - Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) - else - Core.Compiler.return_type(f.val, tt) - end - - if rt === Union{} - error("return type is Union{}, giving up.") - end - rt = guess_activity(rt, mode) - autodiff_deferred(mode, f, rt, args...) -end - """ autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation, Nargs}) diff --git a/test/abi.jl b/test/abi.jl index 63fe48dc61..342722c44d 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -20,13 +20,13 @@ using Test @test () === autodiff(Forward, f, Const(nothing)) - res = autodiff_deferred(Reverse, f, Const(nothing)) + res = autodiff_deferred(Reverse, Const(f), Const, Const(nothing)) @test res === ((nothing,),) - res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing)) + res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), Const(f), Const, Const(nothing)) @test res === ((nothing,),) - @test () === autodiff_deferred(Forward, f, Const(nothing)) - @test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing)) + @test () === autodiff_deferred(Forward, Const(f), Const, Const(nothing)) + @test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), Const(f), Const, Const(nothing)) # ConstType -> Type{Int} res = autodiff(Reverse, f, Const, Const(Int)) @@ -37,9 +37,9 @@ using Test @test res === 
((nothing,),) @test () === autodiff(Forward, f, Const(Int)) - res = autodiff_deferred(Reverse, f, Const(Int)) + res = autodiff_deferred(Reverse, Const(f), Const, Const(Int)) @test res === ((nothing,),) - @test () === autodiff_deferred(Forward, f, Const(Int)) + @test () === autodiff_deferred(Forward, Const(f), Const, Const(Int)) # Complex numbers @test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im)) @@ -54,10 +54,10 @@ using Test cres, = autodiff(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im - @test_throws ErrorException autodiff_deferred(Reverse, f, Active(1.5 + 0.7im)) - @test_throws ErrorException autodiff_deferred(ReverseHolomorphic, f, Active(1.5 + 0.7im)) + @test_throws ErrorException autodiff_deferred(Reverse, Const(f), Active, Active(1.5 + 0.7im)) + @test_throws ErrorException autodiff_deferred(ReverseHolomorphic, Const(f), Active, Active(1.5 + 0.7im)) - cres, = autodiff_deferred(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) + cres, = autodiff_deferred(Forward, Const(f), Duplicated, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im # Unused singleton argument @@ -97,7 +97,7 @@ using Test x = [0.0] dx = [1.2] - autodiff_deferred(Reverse, squareRetArray, Const, Duplicated(x, dx)) + autodiff_deferred(Reverse, Const(squareRetArray), Const, Duplicated(x, dx)) dx = [1.2] @test () === autodiff(Forward, squareRetArray, Const, Duplicated(x, dx)) @@ -113,7 +113,7 @@ using Test @test pair[1] ≈ 3.0 @test pair[2] ≈ 2.0 - pair = autodiff_deferred(Reverse, mul, Active(2.0), Active(3.0))[1] + pair = autodiff_deferred(Reverse, Const(mul), Active, Active(2.0), Active(3.0))[1] @test pair[1] ≈ 3.0 @test pair[2] ≈ 2.0 @@ -122,7 +122,7 @@ using Test @test pair[2] ≈ 2.0 @test orig ≈ 6.0 - pair, orig = autodiff_deferred(ReverseWithPrimal, mul, Active(2.0), Active(3.0)) + pair, orig = autodiff_deferred(ReverseWithPrimal, Const(mul), Active, Active(2.0), Active(3.0)) @test pair[1] ≈ 3.0 @test pair[2] ≈ 2.0 @test orig ≈ 6.0 @@ 
-142,7 +142,7 @@ using Test res = Ref(3.0) dres = Ref(1.0) - pair, orig = autodiff_deferred(ReverseWithPrimal, inplace, Const, Duplicated(res, dres)) + pair, orig = autodiff_deferred(ReverseWithPrimal, Const(inplace), Const, Duplicated(res, dres)) @test pair == (nothing,) @test res[] ≈ 6.0 @test dres[] ≈ 2.0 @@ -163,7 +163,7 @@ using Test res = Ref(3.0) dres = Ref(1.0) - pair, orig = autodiff_deferred(ReverseWithPrimal, inplace2, Const, Duplicated(res, dres)) + pair, orig = autodiff_deferred(ReverseWithPrimal, Const(inplace2), Const, Duplicated(res, dres)) @test pair == (nothing,) @test res[] ≈ 6.0 @test dres[] ≈ 2.0 @@ -450,7 +450,7 @@ end @test r[2] ≈ 100.0 @test r[1][1] ≈ -400.0 @test r[1][2] ≈ 200.0 - r = autodiff_deferred(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) + r = autodiff_deferred(ForwardWithPrimal, Const(rosenbrock_inp), Duplicated, BatchDuplicated(x, (dx_1, dx_2))) @test r[2] ≈ 100.0 @test r[1][1] ≈ -400.0 @test r[1][2] ≈ 200.0 diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 9c9b097422..75318ac97d 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -11,7 +11,7 @@ function mul_kernel(A) end function grad_mul_kernel(A, dA) - autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(mul_kernel), Const, Duplicated(A, dA)) return nothing end @@ -34,7 +34,7 @@ function exp_kernel(A) end function grad_exp_kernel(A, dA) - autodiff_deferred(Reverse, exp_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(exp_kernel), Const, Duplicated(A, dA)) return nothing end @@ -57,7 +57,7 @@ function cos_kernel(A) end function grad_cos_kernel(A, dA) - autodiff_deferred(Reverse, cos_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(cos_kernel), Const, Duplicated(A, dA)) return nothing end diff --git a/test/cuda.jl b/test/cuda.jl index 29a55dcfc8..736f667a87 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -11,7 +11,7 @@ function mul_kernel(A) end function 
grad_mul_kernel(A, dA) - autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(mul_kernel), Const, Duplicated(A, dA)) return nothing end @@ -34,7 +34,7 @@ function exp_kernel(A) end function grad_exp_kernel(A, dA) - autodiff_deferred(Reverse, exp_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(exp_kernel), Const, Duplicated(A, dA)) return nothing end @@ -57,7 +57,7 @@ function cos_kernel(A) end function grad_cos_kernel(A, dA) - autodiff_deferred(Reverse, cos_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(cos_kernel), Const, Duplicated(A, dA)) return nothing end @@ -76,7 +76,7 @@ function val_kernel!(_, ::Val{N}) where N end function dval_kernel!(du, ::Val{N}) where {N} - autodiff_deferred(Reverse, val_kernel!, Const, du, Const(Val(N))) + autodiff_deferred(Reverse, Const(val_kernel!), Const, du, Const(Val(N))) return nothing end @@ -123,7 +123,7 @@ function ddense!( autodiff_deferred( Reverse, - dense!, + Const(dense!), Const, dfeats_out, dfeats_in, dW, db, Const(Val(nfeat_out)), Const(Val(nfeat_in)), Const(Val(ndof)) diff --git a/test/metal.jl b/test/metal.jl index 661bcfbedc..588357c92e 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -16,12 +16,12 @@ function fun_gpu!(A, B, a) end function ∇_fun_cpu!(A, Ā, B, B̄, a) - Enzyme.autodiff_deferred(Reverse, fun_cpu!, Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) + Enzyme.autodiff_deferred(Reverse, Const(fun_cpu!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) nothing end function ∇_fun_gpu!(A_d, Ā_d, B_d, B̄_d, a) - Enzyme.autodiff_deferred(Reverse, fun_gpu!, Const, Duplicated(A_d, Ā_d), Duplicated(B_d, B̄_d), Const(a)) + Enzyme.autodiff_deferred(Reverse, Const(fun_gpu!), Const, Duplicated(A_d, Ā_d), Duplicated(B_d, B̄_d), Const(a)) nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 65ad4e3fd4..d99a28832b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -436,7 +436,7 
@@ end def_A, thunk_A = copy(A), copy(A) primal = Enzyme.autodiff(ReverseWithPrimal, dot, Active, Duplicated(A, dA))[2] @test primal == 34.0 - primal = Enzyme.autodiff_deferred(ReverseWithPrimal, dot, Active, Duplicated(def_A, def_dA))[2] + primal = Enzyme.autodiff_deferred(ReverseWithPrimal, Const(dot), Active, Duplicated(def_A, def_dA))[2] @test primal == 34.0 dup = Duplicated(thunk_A, thunk_dA) @@ -752,7 +752,7 @@ end @testset "Nested AD" begin tonest(x,y) = (x + y)^2 - @test autodiff(Forward, (x,y) -> autodiff_deferred(Forward, tonest, Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] ≈ 2.0 + @test autodiff(Forward, (x,y) -> autodiff(Forward, Const(tonest), Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] ≈ 2.0 end @testset "Hessian" begin @@ -762,7 +762,7 @@ end end function grad(x, dx, y, dy) - Enzyme.autodiff_deferred(Reverse, origf, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) + Enzyme.autodiff(Reverse, Const(origf), Duplicated(x, dx), DuplicatedNoNeed(y, dy)) nothing end @@ -797,7 +797,7 @@ end function f_gradient_deferred!(dx, x, tmp) dtmp = make_zero(tmp) - autodiff_deferred(Reverse, f_ip, Active, Duplicated(x, dx), Duplicated(tmp, dtmp)) + autodiff_deferred(Reverse, Const(f_ip), Active, Duplicated(x, dx), Duplicated(tmp, dtmp)) return nothing end @@ -828,7 +828,7 @@ end function nested_df!(dx, x) make_zero!(dx) - autodiff_deferred(Reverse, nested_f, Active, Duplicated(x, dx)) + autodiff_deferred(Reverse, Const(nested_f), Active, Duplicated(x, dx)) return nothing end @@ -1869,7 +1869,7 @@ end @testset "Mismatched return" begin @test_throws ErrorException autodiff(Reverse, _->missing, Active, Active(2.1)) - @test_throws ErrorException autodiff_deferred(Reverse, _->missing, Active, Active(2.1)) + @test_throws ErrorException autodiff_deferred(Reverse, Const(_->missing), Active, Active(2.1)) end @testset "GCPreserve" begin From efbe9a110ebc89c7b4534c03c93bba46eaa5a634 Mon Sep 17 00:00:00 2001 From: William Moses 
Date: Wed, 18 Sep 2024 13:32:33 -0500 Subject: [PATCH 293/495] Add within autodiff cmd (#1851) * Add within autodiff cmd * fix * fixup * fix * fix * fix --- lib/EnzymeCore/src/EnzymeCore.jl | 8 ++++++++ src/Enzyme.jl | 11 +++++++++-- src/compiler/interpreter.jl | 17 ++++++++++++++++- test/abi.jl | 6 ++++++ 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index cc71f0f9c6..f51c742f5d 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -6,6 +6,7 @@ export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplic export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc +export within_autodiff function batch_size end @@ -338,4 +339,11 @@ if !isdefined(Base, :get_extension) include("../ext/AdaptExt.jl") end +""" + within_autodiff() + +Returns true if within autodiff, otherwise false. +""" +function within_autodiff end + end # module EnzymeCore diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2b6f1f3627..985e8deeea 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,8 +5,8 @@ import EnzymeCore import EnzymeCore: Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, 
BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity +import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff +export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -1744,4 +1744,11 @@ macro import_rrule(args...) return _import_rrule(args...) end +""" + within_autodiff() + +Returns true if within autodiff, otherwise false. +""" +@inline EnzymeCore.within_autodiff() = false + end # module diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 482690e20f..c167581c3a 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -213,7 +213,7 @@ let # overload `inlining_policy` end import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, - CallMeta, Effects, NoCallInfo, widenconst, mapany + CallMeta, Effects, NoCallInfo, widenconst, mapany, MethodResultPure struct AutodiffCallInfo <: CallInfo # ... 
@@ -225,6 +225,21 @@ function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), max_methods::Int = get_max_methods(interp, f, sv)) (; fargs, argtypes) = arginfo + + if f === Enzyme.within_autodiff + if length(argtypes) != 1 + @static if VERSION < v"1.11.0-" + return CallMeta(Union{}, Effects(), NoCallInfo()) + else + return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + end + end + @static if VERSION < v"1.11.0-" + return CallMeta(Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + else + return CallMeta(Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + end + end if f === Enzyme.autodiff && length(argtypes) >= 4 if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} diff --git a/test/abi.jl b/test/abi.jl index 342722c44d..cbd467c155 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -460,6 +460,12 @@ abssum(x) = sum(abs2, x); mulsin(x) = sin(x[1] * x[2]) +@testset "within_autodiff" begin + @test !Enzyme.within_autodiff() + @test_broken Enzyme.autodiff(ForwardWithPrimal, Enzyme.within_autodiff)[1] + @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] +end + @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) From f49d1fc68c3a66aeed2883fb308bec92ebd30103 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 16:42:17 -0500 Subject: [PATCH 294/495] Fix enzymetestutils tests (#1858) --- lib/EnzymeTestUtils/test/test_forward.jl | 1 + lib/EnzymeTestUtils/test/test_reverse.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 57385a1dd9..24de5b2f44 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -20,6 +20,7 @@ function f_kwargs_fwd!(x; kwargs...) 
end function EnzymeRules.forward( + config, func::Const{typeof(f_kwargs_fwd)}, RT::Type{ <:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed} diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl index b394fa171d..901c259af8 100644 --- a/lib/EnzymeTestUtils/test/test_reverse.jl +++ b/lib/EnzymeTestUtils/test/test_reverse.jl @@ -17,7 +17,7 @@ function f_kwargs_rev!(x; kwargs...) end function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(f_kwargs_rev)}, RT::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}}, x::Union{Const,Duplicated}; @@ -39,7 +39,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(f_kwargs_rev)}, dret::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}}, tape, From 6e867ba81bab2abafaed85f56a0f6e7cc38b01a2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 16:46:45 -0500 Subject: [PATCH 295/495] Try llvm.jl 9.1 (#1857) * Try llvm.jl 9.1 * fixups * more fix * bump version * fix * fix * fix --- Project.toml | 4 ++-- src/compiler.jl | 46 +++++++++++++++++++++++++------------- src/compiler/optimize.jl | 20 ++++++++--------- src/compiler/orcv2.jl | 2 +- src/compiler/utils.jl | 10 ++++++++- src/compiler/validation.jl | 4 ++-- 6 files changed, 55 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 19315d01dc..9cf8028a76 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.0" +version = "0.13.1" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -37,7 +37,7 @@ ChainRulesCore = "1" EnzymeCore = "0.8" Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" -LLVM = "6.1, 7, 8, =9.0" 
+LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" diff --git a/src/compiler.jl b/src/compiler.jl index 1d21fb99a1..be4679d263 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4045,6 +4045,22 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr return adjointf, augmented_primalf, TapeType end +function get_subprogram(f::LLVM.Function) + @static if isdefined(LLVM, :subprogram) + LLVM.subprogram(f) + else + LLVM.get_subprogram(f) + end +end + +function set_subprogram!(f::LLVM.Function, sp) + @static if isdefined(LLVM, :subprogram) + LLVM.subprogram!(f, sp) + else + LLVM.set_subprogram!(f, sp) + end +end + function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, Mode::API.CDerivativeMode, augmented, width, returnPrimal, shadow_init, world, interp) is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal @@ -4422,8 +4438,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(args, psret) end res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args) - if LLVM.get_subprogram(llvmf) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvmf) !== nothing + metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end if psret !== nothing res = load!(builder, convert(LLVMType, Func_RT), psret) @@ -4449,8 +4465,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvm_f) !== nothing + metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end @inline function fixup_abi(index, value) @@ 
-4514,8 +4530,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(function_attributes(cf), EnumAttribute("alwaysinline", 0)) for shadowv in shadows c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvm_f) !== nothing + metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end end end @@ -5027,9 +5043,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function wrapper_ft = LLVM.FunctionType(RT, wrapper_types) wrapper_f = LLVM.Function(mod, LLVM.name(entry_f), wrapper_ft) callconv!(wrapper_f, callconv(entry_f)) - sfn = LLVM.get_subprogram(entry_f) + sfn = get_subprogram(entry_f) if sfn !== nothing - LLVM.set_subprogram!(wrapper_f, sfn) + set_subprogram!(wrapper_f, sfn) end hasReturnsTwice = any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(entry_f)))) @@ -5107,8 +5123,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function entry = BasicBlock(wrapper_f, "entry") position!(builder, entry) - if LLVM.get_subprogram(entry_f) !== nothing - debuglocation!(builder, DILocation(0, 0, LLVM.get_subprogram(entry_f))) + if get_subprogram(entry_f) !== nothing + debuglocation!(builder, DILocation(0, 0, get_subprogram(entry_f))) end wrapper_args = Vector{LLVM.Value}() @@ -5178,8 +5194,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end res = call!(builder, LLVM.function_type(entry_f), entry_f, wrapper_args) - if LLVM.get_subprogram(entry_f) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(entry_f) ) + if get_subprogram(entry_f) !== nothing + metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(entry_f) ) end callconv!(res, LLVM.callconv(entry_f)) @@ -5411,10 +5427,10 @@ function lower_convention(functy::Type, 
mod::LLVM.Module, entry_f::LLVM.Function LLVM.run!(pm, mod) end if haskey(globals(mod), "llvm.used") - unsafe_delete!(mod, globals(mod)["llvm.used"]) + eraseInst(mod, globals(mod)["llvm.used"]) for u in user.(collect(uses(entry_f))) if isa(u, LLVM.GlobalVariable) && endswith(LLVM.name(u), "_slot") && startswith(LLVM.name(u), "julia") - unsafe_delete!(mod, u) + eraseInst(mod, u) end end end @@ -6469,7 +6485,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; st = LLVM.user(u) LLVM.API.LLVMInstructionEraseFromParent(st) end - LLVM.unsafe_delete!(mod, f) + eraseInst(mod, f) end linkage!(adjointf, LLVM.API.LLVMExternalLinkage) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 8c6385edb8..2e3e8194c9 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -533,7 +533,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) end end for inst in todel - unsafe_delete!(LLVM.parent(inst), inst) + eraseInst(LLVM.parent(inst), inst) end end end @@ -1145,7 +1145,7 @@ function prop_global!(g) end end replace_uses!(var, res) - unsafe_delete!(LLVM.parent(var), var) + eraseInst(LLVM.parent(var), var) continue end if isa(var, LLVM.AddrSpaceCastInst) @@ -1441,7 +1441,7 @@ function propagate_returned!(mod::LLVM.Module) end if !illegalUse for c in reverse(torem) - unsafe_delete!(LLVM.parent(c), c) + eraseInst(LLVM.parent(c), c) end B = IRBuilder() position!(B, first(instructions(first(blocks(fn))))) @@ -1617,7 +1617,7 @@ function propagate_returned!(mod::LLVM.Module) end API.EnzymeSetCalledFunction(un, nfn, toremove) end - unsafe_delete!(mod, fn) + eraseInst(mod, fn) changed = true catch break @@ -2030,26 +2030,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm) for u in LLVM.uses(rfunc) u = LLVM.user(u) - unsafe_delete!(LLVM.parent(u), u) + eraseInst(LLVM.parent(u), u) end - unsafe_delete!(mod, rfunc) + eraseInst(mod, rfunc) for u in LLVM.uses(sfunc) u = LLVM.user(u) - unsafe_delete!(LLVM.parent(u), u) + 
eraseInst(LLVM.parent(u), u) end - unsafe_delete!(mod, sfunc) + eraseInst(mod, sfunc) for fn in functions(mod) for b in blocks(fn) inst = first(LLVM.instructions(b)) if isa(inst, LLVM.CallInst) fn = LLVM.called_operand(inst) if fn == func - unsafe_delete!(b, inst) + eraseInst(b, inst) end end end end - unsafe_delete!(mod, func) + eraseInst(mod, func) end function optimize!(mod::LLVM.Module, tm) diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 40d13eea80..78ff089e7d 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -224,7 +224,7 @@ function get_trampoline(job) # but it would be nicer if _thunk just codegen'd the half # we need. other_func = functions(mod)[other_name] - LLVM.unsafe_delete!(mod, other_func) + Compiler.eraseInst(mod, other_func) end tsm = move_to_threadsafe(mod) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 6615b6bd40..cde5d2cade 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -313,6 +313,14 @@ function reinsert_gcmarker!(func, PB=nothing) end end +function eraseInst(bb, inst) + @static if isdefined(LLVM, Symbol("erase!")) + LLVM.erase!(inst) + else + unsafe_delete!(bb, inst) + end +end + function unique_gcmarker!(func) entry_bb = first(blocks(func)) pgcstack_func = declare_pgcstack!(LLVM.parent(func)) @@ -327,7 +335,7 @@ function unique_gcmarker!(func) for i in 2:length(found) LLVM.replace_uses!(found[i], found[1]) ops = LLVM.collect(operands(found[i])) - Base.unsafe_delete!(entry_bb, found[i]) + eraseInst(entry_bb, found[i]) end end return nothing diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 51aeacf675..3df37be117 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -112,7 +112,7 @@ function restore_lookups(mod::LLVM.Module) if haskey(functions(mod), k) f = functions(mod)[k] replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstIntToPtr(ConstantInt(T_size_t, convert(UInt, v)), value_type(f)))) - unsafe_delete!(mod, f) + eraseInst(mod, f) 
end end end @@ -272,7 +272,7 @@ function check_ir!(job, errors, mod::LLVM.Module) mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, parameters(prev_ft))) replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) - unsafe_delete!(mod, f) + eraseInst(mod, f) end rewrite_ccalls!(mod) for f in collect(functions(mod)) From 5a5beea54eabc9117a371397847dfaa62cd6c161 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 19 Sep 2024 14:47:17 -0500 Subject: [PATCH 296/495] runtime activity lookup on val not type (#1862) * runtime activity lookup on val not type * fix --------- Co-authored-by: William Moses --- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/rules.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 2d39f92f45..0b688f27a8 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.0" +version = "0.8.1" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index d4469e9793..3da3e318a7 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -293,6 +293,6 @@ Mark a particular type `Ty` as always being inactive. 
""" inactive_type(::Type) = false -@inline EnzymeCore.set_runtime_activity(::M, ::Config) where {M<:Mode, Config <: Union{FwdConfig, RevConfig}} = EnzymeCore.set_runtime_activity(M, runtime_activity(Config)) +@inline EnzymeCore.set_runtime_activity(mode::M, config::Config) where {M<:Mode, Config <: Union{FwdConfig, RevConfig}} = EnzymeCore.set_runtime_activity(mode, runtime_activity(config)) end # EnzymeRules From f14bd4a6bb3de73c205ca181bf8f69d71365822c Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 19 Sep 2024 20:49:22 -0500 Subject: [PATCH 297/495] Fix jac nout (#1864) --- Project.toml | 2 +- src/Enzyme.jl | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 9cf8028a76..3c93057f90 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.1" +version = "0.13.2" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 985e8deeea..a439cf430c 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1417,8 +1417,8 @@ end jacobian(::ReverseMode, f, x) Compute the jacobian of a array-output function `f` using (potentially vector) -reverse mode. The `chunk` argument denotes the chunk size to use and `n_outs` -denotes the shape of the array returned by `f`. +reverse mode. The `chunk` argument optionally denotes the chunk size to use and +`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`). 
Example: @@ -1434,12 +1434,30 @@ jacobian(Reverse, f, [2.0, 3.0, 4.0]) ```jldoctest f(x) = [ x[1] * x[2], x[2] + x[3] ] +grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0]) + +# output +(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) +``` + +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) # output ([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) + +# output +(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) +``` + This function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made about the type of the AbstractArray returned by this function (which may or may not be the same as the input AbstractArray if provided). @@ -1573,7 +1591,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end if ReturnPrimal # TODO optimize away redundant fwd pass - (; derivs=res, val=if f isa Enzyme.Const + (; derivs=(res,), val=if f isa Enzyme.Const f.val(x) else f(x) From 00037e7ff8fb32f36691bbdba5ce8dc251fe2dec Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 02:16:23 -0500 Subject: [PATCH 298/495] Cleanup (#1872) * Return type config * cleanup * Update runtests.jl --------- Co-authored-by: William Moses --- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/rules.jl | 15 ++---- lib/EnzymeTestUtils/src/generate_tangent.jl | 10 +--- lib/EnzymeTestUtils/src/test_reverse.jl | 7 +-- lib/EnzymeTestUtils/test/test_forward.jl | 53 +++++++-------------- src/rules/customrules.jl | 8 ++-- test/runtests.jl | 15 +++--- test/threads.jl | 5 -- 8 files changed, 35 insertions(+), 80 deletions(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 0b688f27a8..37ddaf6457 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name 
= "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.1" +version = "0.8.2" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 3da3e318a7..a7563a2ef7 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -77,17 +77,21 @@ const RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} @inline runtime_activity(::RevConfig{<:Any, <:Any, <:Any, <:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity """ + primal_type(::FwdConfig, ::Type{<:Annotation{RT}}) primal_type(::RevConfig, ::Type{<:Annotation{RT}}) Compute the exepcted primal return type given a reverse mode config and return activity """ +@inline primal_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing @inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing """ + shadow_type(::FwdConfig, ::Type{<:Annotation{RT}}) shadow_type(::RevConfig, ::Type{<:Annotation{RT}}) Compute the exepcted shadow return type given a reverse mode config and return activity """ +@inline shadow_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing @inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? 
RT : NTuple{width(config), RT}) : Nothing """ @@ -191,9 +195,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); caller::Union{Nothing,Core.MethodInstance}=nothing) tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) - @static if VERSION < v"1.7.0" - return !isempty(Base._methods_by_ftype(sig, -1, world)) - end mt = ccall(:jl_method_table_for, Any, (Any,), sig) mt isa Core.MethodTable || return false if method_table === nothing @@ -234,14 +235,6 @@ function add_mt_backedge!(caller::Core.MethodInstance, mt::Core.MethodTable, @no return nothing end -function issupported() - @static if VERSION < v"1.7.0" - return false - else - return true - end -end - """ inactive(func::typeof(f), args...) diff --git a/lib/EnzymeTestUtils/src/generate_tangent.jl b/lib/EnzymeTestUtils/src/generate_tangent.jl index d774036e7e..91822a509f 100644 --- a/lib/EnzymeTestUtils/src/generate_tangent.jl +++ b/lib/EnzymeTestUtils/src/generate_tangent.jl @@ -60,14 +60,8 @@ end # get around the constructors and make the type directly # Note this is moderately evil accessing julia's internals -if VERSION >= v"1.3" - @generated function _force_construct(T, args...) - return Expr(:splatnew, :T, :args) - end -else - @generated function _force_construct(T, args...) - return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...) - end +@generated function _force_construct(T, args...) + return Expr(:splatnew, :T, :args) end function _construct(T, args...) diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index 6c20aebb7a..543f5de699 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -7,12 +7,7 @@ for N in 1:30 eval(quote function call_with_kwargs(fkwargs::NT, f::FT, $(argexprs...)) where {NT, FT} Base.@_inline_meta - @static if VERSION ≤ v"1.8" - # callsite inline syntax unsupported in <= 1.8 - f($(argexprs...); fkwargs...) - else - @inline f($(argexprs...); fkwargs...) 
- end + @inline f($(argexprs...); fkwargs...) end end) end diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 24de5b2f44..5f8e5e7c6c 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -87,9 +87,6 @@ end elseif TT <: NamedTuple x = (a=randn(T), b=randn(T)) else # TT <: TestStruct - if VERSION <= v"1.8" && Tx == BatchDuplicated - continue - end x = TestStruct(randn(T, 5), randn(T)) end atol = rtol = sqrt(eps(real(T))) @@ -117,38 +114,26 @@ end a = randn(T) atol = rtol = sqrt(eps(real(T))) - if VERSION < v"1.8" && ( - Tret <: BatchDuplicated || - Tx <: BatchDuplicated || - Ta <: BatchDuplicated - ) - @test !fails() do - test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol) - end skip = true - else - @test !fails() do - test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol) - end broken = ( - VERSION < v"1.8" && Tx <: Const && !(Ta <: Const) && T <: Complex - ) - end + @test !fails() do + test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol) + end end end - VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin - @testset for Tret in (Const, Duplicated, BatchDuplicated), - Tx in (Const, Duplicated, BatchDuplicated), - T in (Float32, Float64, ComplexF32, ComplexF64) + @testset "structured array inputs/outputs" begin + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + T in (Float32, Float64, ComplexF32, ComplexF64) - # if some are batch, none must be duplicated - are_activities_compatible(Tret, Tx) || continue + # if some are batch, none must be duplicated + are_activities_compatible(Tret, Tx) || continue - x = Hermitian(randn(T, 5, 5)) + x = Hermitian(randn(T, 5, 5)) - atol = rtol = sqrt(eps(real(T))) - test_forward(f_structured_array, Tret, (x, Tx); atol, rtol) - end - end + atol = rtol = sqrt(eps(real(T))) + test_forward(f_structured_array, Tret, (x, Tx); atol, rtol) + end + 
end @testset "equivalent arrays in output" begin function f(x) @@ -197,7 +182,7 @@ end atol = rtol = sqrt(eps(real(T))) @test !fails() do test_forward(f_mut_fwd!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol, runtime_activity=true) - end skip = (VERSION < v"1.8" && T <: Complex) + end end end @@ -230,13 +215,7 @@ end atol = rtol = sqrt(eps(real(T))) @test !fails() do test_forward((c, Tc), Tret, (y, Ty); atol, rtol) - end skip = ( - VERSION < v"1.8" && ( - Tret <: BatchDuplicated || - Tc <: BatchDuplicated || - Ty <: BatchDuplicated - ) - ) + end end end end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index e0eae36e4d..08cd15facb 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -497,7 +497,7 @@ end if RT <: Const if needsPrimal if RealRt != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT)) + emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT)) return false end if get_return_info(RealRt)[2] !== nothing @@ -508,7 +508,7 @@ end end else if Nothing != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT)) + emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT)) return false end end @@ -519,7 +519,7 @@ end ST = NTuple{Int(width), ST} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of shadow-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error(B, orig, 
"Enzyme: incorrect return type of shadow-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) return false end if get_return_info(RealRt)[2] !== nothing @@ -539,7 +539,7 @@ end BatchDuplicated{RealRt, Int(width)} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) return false end if get_return_info(RealRt)[2] !== nothing diff --git a/test/runtests.jl b/test/runtests.jl index d99a28832b..573140f2c2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -85,14 +85,13 @@ end include("abi.jl") include("typetree.jl") -@static if Enzyme.EnzymeRules.issupported() - include("rules.jl") - include("rrules.jl") - include("kwrules.jl") - include("kwrrules.jl") - include("internal_rules.jl") - include("ruleinvalidation.jl") -end +include("rules.jl") +include("rrules.jl") +include("kwrules.jl") +include("kwrrules.jl") +include("internal_rules.jl") +include("ruleinvalidation.jl") + @static if !Sys.iswindows() include("blas.jl") end diff --git a/test/threads.jl b/test/threads.jl index 6899d8d2d6..9a06869c88 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -73,14 +73,9 @@ end out = [1.0, 2.0] dout = [1.0, 1.0] -@static if VERSION < v"1.8" - # GPUCompiler causes a stack overflow due to https://github.com/JuliaGPU/GPUCompiler.jl/issues/587 - # @test_throws AssertionError autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) -else res = autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) @test res[1][2] ≈ 2.0 end -end @testset "Closure-less threads $(Threads.nthreads())" begin function bf(i, x) From 
0d6fe67ff400218d24d8c4aee9591852d8f90710 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 14:48:31 -0500 Subject: [PATCH 299/495] fix (#1877) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index be4679d263..07cbdeb65a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5746,7 +5746,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end - if !(haskey(functions(mod), k_name) || has_custom_rule) + if !haskey(functions(mod), k_name) continue end From 29ed385d498a70b8d41da9d88a366707c7263388 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 15:38:09 -0500 Subject: [PATCH 300/495] Try fixing buildkite (#1843) * Try fixing buildkite * Update pipeline.yml * Update pipeline.yml * Update pipeline.yml --- .buildkite/pipeline.yml | 67 ++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6de936a558..1a9f70d04c 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -15,10 +15,10 @@ steps: commands: | echo "--- Setup Julia packages" julia --color=yes -e ' - import Pkg - Pkg.develop(; path = pwd()) - Pkg.develop(; path = joinpath(pwd(), "lib", "EnzymeCore")) - Pkg.develop(; name = "CUDA")' || exit 3 + using Pkg + pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] + push!(pkgs, PackageSpec(; name="CUDA")) + Pkg.develop(pkgs)' || exit 3 echo "+++ Run tests" julia --color=yes test/cuda.jl @@ -41,40 +41,39 @@ steps: commands: | echo "--- Setup Julia packages" julia --color=yes -e ' - import Pkg - Pkg.develop(; path = pwd()) - Pkg.develop(; path = joinpath(pwd(), "lib", "EnzymeCore")) - Pkg.develop(; name = "AMDGPU")' || exit 3 + using Pkg + pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] + push!(pkgs, PackageSpec(; name="AMDGPU")) + 
Pkg.develop(pkgs)' || exit 3 echo "+++ Run tests" julia --color=yes test/amdgpu.jl env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - # - label: "Metal Julia v{{matrix.version}}" - # matrix: - # setup: - # version: - # - "1.8" - # - "1.9" - # plugins: - # - JuliaCI/julia#v1: - # version: "{{matrix.version}}" - # agents: - # queue: "juliaecosystem" - # os: "macos" - # arch: "aarch64" - # if: build.message !~ /\[skip tests\]/ - # timeout_in_minutes: 60 - # commands: | - # echo "--- Setup Julia packages" - # julia --color=yes -e ' - # import Pkg - # Pkg.develop(; path = pwd()) - # Pkg.develop(; path = joinpath(pwd(), "lib", "EnzymeCore")) - # Pkg.develop(; name = "Metal")' || exit 3 + - label: "Metal Julia v{{matrix.version}}" + matrix: + setup: + version: + - "1.10" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.version}}" + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + commands: | + echo "--- Setup Julia packages" + julia --color=yes -e ' + using Pkg + pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] + push!(pkgs, PackageSpec(; name="Metal")) + Pkg.develop(pkgs)' || exit 3 - # echo "+++ Run tests" - # julia --color=yes test/metal.jl - # env: - # JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + echo "+++ Run tests" + julia --color=yes test/metal.jl + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager From a08b81033b7f1a335807d56ef3cd2251781b9ad5 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 21 Sep 2024 16:40:58 -0400 Subject: [PATCH 301/495] Remove deprecated UnionAll Vararg (#1859) * Remove deprecated UnionAll Vararg * Replace remaining uses of Vararg in docstrings `...` is better understood by more users and easier on the eyes --- src/Enzyme.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index a439cf430c..7a864aa51a 100644 --- a/src/Enzyme.jl +++ 
b/src/Enzyme.jl @@ -170,7 +170,7 @@ end end """ - autodiff(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff(::ReverseMode, f, Activity, args::Annotation...) Auto-differentiate function `f` at arguments `args` using reverse mode. @@ -317,7 +317,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end """ - autodiff(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) + autodiff(mode::Mode, f, ::Type{A}, args::Annotation...) Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ @@ -345,7 +345,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. end """ - autodiff(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff(::ForwardMode, f, Activity, args::Annotation...) Auto-differentiate function `f` at arguments `args` using forward mode. @@ -431,7 +431,7 @@ f(x) = x*x end """ - autodiff_deferred(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff_deferred(::ReverseMode, f, Activity, args::Annotation...) Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. @@ -472,9 +472,9 @@ code, as well as high-order differentiation. end """ - autodiff_deferred(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff_deferred(::ForwardMode, f, Activity, args::Annotation...) -Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU +Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred(::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} @@ -532,7 +532,7 @@ code, as well as high-order differentiation. 
end """ - autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation, Nargs}) + autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -628,7 +628,7 @@ end end """ - autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) + autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the thunk forward mode function for annotated function type ftype when called with args of type `argtypes`. @@ -798,7 +798,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType end """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) + autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. 
@@ -1067,7 +1067,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) ``` """ -@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{<:Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} +@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} toemit= Expr[quote act_0 = !(x isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof(x), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState end] From 0c36c5af6e7ed02a25e5bb7485d6477f97ca7eed Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 21:16:38 -0500 Subject: [PATCH 302/495] Use correct triple (#1878) * Use correct triple * fix * fix * fix --- src/compiler.jl | 4 +-- src/compiler/orcv2.jl | 63 +++++-------------------------------------- 2 files changed, 9 insertions(+), 58 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 07cbdeb65a..ce51a6e7f5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3293,9 +3293,9 @@ end # Define EnzymeTarget Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget end -GPUCompiler.llvm_triple(::EnzymeTarget) = Sys.MACHINE -# GPUCompiler.llvm_datalayout(::EnzymeTarget) = nothing +GPUCompiler.llvm_triple(::EnzymeTarget) = LLVM.triple(JIT.get_jit()) +GPUCompiler.llvm_datalayout(::EnzymeTarget) = LLVM.datalayout(JIT.get_jit()) function GPUCompiler.llvm_machine(::EnzymeTarget) return JIT.get_tm() diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 78ff089e7d..482a961b52 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -9,24 +9,12 @@ import GPUCompiler import ..Compiler import ..Compiler: API, cpu_name, cpu_features -@inline function use_ojit() - return !Sys.iswindows() -end - export 
get_trampoline -@static if use_ojit() - struct CompilerInstance - jit::LLVM.JuliaOJIT - lctm::Union{LLVM.LazyCallThroughManager, Nothing} - ism::Union{LLVM.IndirectStubsManager, Nothing} - end -else - struct CompilerInstance - jit::LLVM.LLJIT - lctm::Union{LLVM.LazyCallThroughManager, Nothing} - ism::Union{LLVM.IndirectStubsManager, Nothing} - end +struct CompilerInstance + jit::LLVM.JuliaOJIT + lctm::Union{LLVM.LazyCallThroughManager, Nothing} + ism::Union{LLVM.IndirectStubsManager, Nothing} end function LLVM.dispose(ci::CompilerInstance) @@ -44,6 +32,7 @@ const jit = Ref{CompilerInstance}() const tm = Ref{TargetMachine}() # for opt pipeline get_tm() = tm[] +get_jit() = jit[].jit function absolute_symbol_materialization(name, ptr) address = LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr)) @@ -80,37 +69,7 @@ function __init__() LLVM.asm_verbosity!(tempTM, true) tm[] = tempTM - lljit = @static if !use_ojit() - tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel) - LLVM.asm_verbosity!(tempTM, true) - - gdb = haskey(ENV, "ENABLE_GDBLISTENER") - perf = haskey(ENV, "ENABLE_JITPROFILING") - if gdb || perf - ollc = LLVM.ObjectLinkingLayerCreator() do es, triple - oll = ObjectLinkingLayer(es) - if gdb - register!(oll, GDBRegistrationListener()) - end - if perf - register!(oll, IntelJITEventListener()) - register!(oll, PerfJITEventListener()) - end - return oll - end - GC.@preserve ollc begin - builder = LLJITBuilder() - LLVM.linkinglayercreator!(builder, ollc) - tmb = TargetMachineBuilder(tempTM) - LLVM.targetmachinebuilder!(builder, tmb) - LLJIT(builder) - end - else - LLJIT(;tm=tempTM) - end - else - JuliaOJIT() - end + lljit = JuliaOJIT() jd_main = JITDylib(lljit) @@ -145,10 +104,6 @@ function __init__() end atexit() do - @static if !use_ojit() - ci = jit[] - dispose(ci) - end dispose(tm[]) end end @@ -229,11 +184,7 @@ function get_trampoline(job) tsm = move_to_threadsafe(mod) - il = @static if use_ojit() - 
LLVM.IRCompileLayer(lljit) - else - LLVM.IRTransformLayer(lljit) - end + il = LLVM.IRCompileLayer(lljit) LLVM.emit(il, mr, tsm) end return nothing From 6bfe8e0bf09e4bba62da490e073a95105ceed20a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Sep 2024 22:36:50 -0500 Subject: [PATCH 303/495] Cleanup absint (#1880) * Cleanup absint * cleanup * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update Project.toml * fix * Update runtests.jl * Update runtests.jl --- Project.toml | 2 +- src/Enzyme.jl | 1126 +++++-- src/absint.jl | 365 ++- src/api.jl | 1171 +++++-- src/compiler.jl | 5296 +++++++++++++++++++++++--------- src/compiler/interpreter.jl | 210 +- src/compiler/optimize.jl | 1304 +++++--- src/compiler/orcv2.jl | 89 +- src/compiler/passes.jl | 2 +- src/compiler/reflection.jl | 107 +- src/compiler/utils.jl | 169 +- src/compiler/validation.jl | 864 ++++-- src/gradientutils.jl | 66 +- src/internal_rules.jl | 590 ++-- src/pmap.jl | 66 +- src/rules/activityrules.jl | 70 +- src/rules/allocrules.jl | 105 +- src/rules/customrules.jl | 815 +++-- src/rules/jitrules.jl | 1476 +++++++-- src/rules/llvmrules.jl | 1082 +++++-- src/rules/parallelrules.jl | 359 ++- src/rules/typerules.jl | 18 +- src/rules/typeunstablerules.jl | 981 ++++-- src/typeanalysis.jl | 7 +- src/typetree.jl | 104 +- src/utils.jl | 64 +- test/runtests.jl | 11 + test/typetree.jl | 13 + 28 files changed, 12040 insertions(+), 4492 deletions(-) diff --git a/Project.toml b/Project.toml index 3c93057f90..5a0e192de5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.2" +version = "0.13.3" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7a864aa51a..c99114e038 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -2,11 +2,68 @@ module Enzyme import EnzymeCore 
-import EnzymeCore: Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal - -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff +import EnzymeCore: + Forward, + ForwardWithPrimal, + Reverse, + ReverseWithPrimal, + ReverseSplitNoPrimal, + ReverseSplitWithPrimal, + ReverseSplitModified, + ReverseSplitWidth, + ReverseMode, + ForwardMode, + ReverseHolomorphic, + ReverseHolomorphicWithPrimal +export Forward, + ForwardWithPrimal, + Reverse, + ReverseWithPrimal, + ReverseSplitNoPrimal, + ReverseSplitWithPrimal, + ReverseSplitModified, + ReverseSplitWidth, + ReverseMode, + ForwardMode, + ReverseHolomorphic, + ReverseHolomorphicWithPrimal + +import EnzymeCore: + Annotation, + Const, + Active, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + ABI, + DefaultABI, + FFIABI, + InlineABI, + NonGenABI, + set_err_if_func_written, + clear_err_if_func_written, + set_abi, + set_runtime_activity, + clear_runtime_activity, + within_autodiff +export Annotation, + Const, + Active, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + DefaultABI, + FFIABI, 
+ InlineABI, + NonGenABI, + set_err_if_func_written, + clear_err_if_func_written, + set_abi, + set_runtime_activity, + clear_runtime_activity, + within_autodiff import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -14,11 +71,24 @@ export BatchDuplicatedFunc import EnzymeCore: MixedDuplicated, BatchMixedDuplicated export MixedDuplicated, BatchMixedDuplicated -import EnzymeCore: batch_size, get_func +import EnzymeCore: batch_size, get_func export batch_size, get_func -import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! -export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! +import EnzymeCore: + autodiff, + autodiff_deferred, + autodiff_thunk, + autodiff_deferred_thunk, + tape_type, + make_zero, + make_zero! +export autodiff, + autodiff_deferred, + autodiff_thunk, + autodiff_deferred_thunk, + tape_type, + make_zero, + make_zero! export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient! 
export markType, batch_size, onehot, chunkedonehot @@ -58,7 +128,7 @@ import .Compiler: CompilationException end end -@inline function any_active(args::Vararg{Annotation, N}) where N +@inline function any_active(args::Vararg{Annotation,N}) where {N} any(ntuple(Val(N)) do i Base.@_inline_meta arg = @inbounds args[i] @@ -74,18 +144,22 @@ end end) end -@inline function vaTypeof(args::Vararg{Any, N}) where N - return Tuple{(ntuple(Val(N)) do i - Base.@_inline_meta - Core.Typeof(args[i]) - end)...} +@inline function vaTypeof(args::Vararg{Any,N}) where {N} + return Tuple{( + ntuple(Val(N)) do i + Base.@_inline_meta + Core.Typeof(args[i]) + end + )...} end -@inline function vaEltypes(args::Type{Ty}) where {Ty <: Tuple} - return Tuple{(ntuple(Val(length(Ty.parameters))) do i - Base.@_inline_meta - eltype(Ty.parameters[i]) - end)...} +@inline function vaEltypes(args::Type{Ty}) where {Ty<:Tuple} + return Tuple{( + ntuple(Val(length(Ty.parameters))) do i + Base.@_inline_meta + eltype(Ty.parameters[i]) + end + )...} end @inline function same_or_one_helper(current, next) @@ -99,22 +173,28 @@ end end @inline same_or_one_rec(current) = current -@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchMixedDuplicated{T, N}}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T, N}}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::BatchDuplicated{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchDuplicated{T, N}}, args...) 
where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::BatchDuplicatedNoNeed{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchDuplicatedNoNeed{T, N}}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec( + current, + arg::Type{BatchMixedDuplicated{T,N}}, + args..., +) where {T,N} = same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T,N}}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchDuplicated{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::Type{BatchDuplicated{T,N}}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchDuplicatedNoNeed{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec( + current, + arg::Type{BatchDuplicatedNoNeed{T,N}}, + args..., +) where {T,N} = same_or_one_rec(same_or_one_helper(current, N), args...) @inline same_or_one_rec(current, arg, args...) = same_or_one_rec(current, args...) @inline function same_or_one(defaultVal, args...) 
@@ -127,7 +207,7 @@ end end -@inline function refn_seed(x::T) where T +@inline function refn_seed(x::T) where {T} if T <: Complex return conj(x) / 2 else @@ -135,7 +215,7 @@ end end end -@inline function imfn_seed(x::T) where T +@inline function imfn_seed(x::T) where {T} if T <: Complex return im * conj(x) / 2 else @@ -143,7 +223,11 @@ end end end -@inline function seed_complex_args(seen, seen2, args::Vararg{Annotation, Nargs}) where {Nargs} +@inline function seed_complex_args( + seen, + seen2, + args::Vararg{Annotation,Nargs}, +) where {Nargs} return ntuple(Val(Nargs)) do i Base.@_inline_meta arg = args[i] @@ -151,18 +235,29 @@ end arg elseif arg isa Duplicated || arg isa DuplicatedNoNeed RT = eltype(Core.Typeof(arg)) - BatchDuplicated(arg.val, (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval))) + BatchDuplicated( + arg.val, + (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval)), + ) else - throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + throw( + ErrorException( + "Active Complex return does not yet support batching in combined reverse mode", + ), + ) end end end -@inline function fuse_complex_results(results, args::Vararg{Annotation, Nargs}) where {Nargs} +@inline function fuse_complex_results(results, args::Vararg{Annotation,Nargs}) where {Nargs} ntuple(Val(Nargs)) do i Base.@_inline_meta if args[i] isa Active - Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn_seed), results[1][i][3], imfn_seed) + Compiler.recursive_add( + Compiler.recursive_add(results[1][i][1], results[1][i][2], refn_seed), + results[1][i][3], + imfn_seed, + ) else results[1][i] end @@ -229,16 +324,30 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. 
""" -@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RuntimeActivity,RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RuntimeActivity, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten} - tt′ = vaTypeof(args...) +@inline function autodiff( + rmode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + RuntimeActivity, + RABI<:ABI, + Holomorphic, + Nargs, + ErrIfFuncWritten, +} + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} FTy = Core.Typeof(f.val) @@ -251,12 +360,25 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) rt = if A isa UnionAll Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) else - eltype(A) + eltype(A) end if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + forward, adjoint = Enzyme.Compiler.thunk( + opt_mi, + FA, + Duplicated{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(true), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# res = forward(f, args...) 
tape = res[1] if ReturnPrimal @@ -265,7 +387,11 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) return adjoint(f, args..., tape) end end - elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed || A <: BatchDuplicatedFunc + elseif A <: Duplicated || + A <: DuplicatedNoNeed || + A <: BatchDuplicated || + A <: BatchDuplicatedNoNeed || + A <: BatchDuplicatedFunc throw(ErrorException("Duplicated Returns not yet handled")) end @@ -277,16 +403,40 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) f = if f isa Const || f isa Active f elseif f isa Duplicated || f isa DuplicatedNoNeed - BatchDuplicated(f.val, (f.dval, make_zero(typeof(f), seen, f.dval), make_zero(typeof(f), seen2, f.dval))) + BatchDuplicated( + f.val, + ( + f.dval, + make_zero(typeof(f), seen, f.dval), + make_zero(typeof(f), seen2, f.dval), + ), + ) else - throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + throw( + ErrorException( + "Active Complex return does not yet support batching in combined reverse mode", + ), + ) end width = same_or_one(3, args...) args = seed_complex_args(seen, seen2, args...) - tt′ = vaTypeof(args...) - - thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + tt′ = vaTypeof(args...) + + thunk = Enzyme.Compiler.thunk( + opt_mi, + typeof(f), + A, + tt′, + Val(API.DEM_ReverseModeCombined), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# results = thunk(f, args..., (rt(0), rt(1), rt(im))) @@ -305,10 +455,27 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) return (fused, results[2:end]...) 
end - throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.")) - end - - thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + throw( + ErrorException( + "Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.", + ), + ) + end + + thunk = Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + tt′, + Val(API.DEM_ReverseModeCombined), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# if A <: Active args = (args..., Compiler.default_adjoint(rt)) @@ -321,10 +488,19 @@ end Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ -@inline function autodiff(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} +@inline function autodiff( + mode::CMode, + f::F, + args::Vararg{Annotation,Nargs}, +) where {F,CMode<:Mode,Nargs} autodiff(EnzymeCore.set_err_if_func_written(mode), Const(f), args...) end -@inline function autodiff(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} +@inline function autodiff( + mode::CMode, + f::F, + ::Type{RT}, + args::Vararg{Annotation,Nargs}, +) where {F,RT<:Annotation,CMode<:Mode,Nargs} autodiff(EnzymeCore.set_err_if_func_written(mode), Const(f), RT, args...) end @@ -333,14 +509,23 @@ end Like [`autodiff`](@ref) but will try to guess the activity of the return value. 
""" -@inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode - Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) +@inline function autodiff( + mode::CMode, + f::FA, + args::Vararg{Annotation,Nargs}, +) where {FA<:Annotation,CMode<:Mode,Nargs} + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + rt = if mode isa ReverseMode + Compiler.primal_return_type( + mode, + Val(codegen_world_age(eltype(FA), tt)), + eltype(FA), + tt, + ) else Core.Compiler.return_type(f.val, tt) end - A = guess_activity(rt, mode) + A = guess_activity(rt, mode) autodiff(mode, f, A, args...) end @@ -384,11 +569,19 @@ f(x) = x*x (6.28,) ``` """ -@inline function autodiff(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {ReturnPrimal, RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff( + ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + FA<:Annotation, + A<:Annotation, +} where {ReturnPrimal,RABI<:ABI,Nargs,ErrIfFuncWritten,RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = vaTypeof(args...) + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -397,27 +590,31 @@ f(x) = x*x throw(ErrorException("Active Returns not allowed in forward mode")) end if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed - throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)")) + throw( + ErrorException( + "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)", + ), + ) end RT = if A <: Duplicated && width != 1 if A isa UnionAll - BatchDuplicated{T, width} where T + BatchDuplicated{T,width} where {T} else - BatchDuplicated{eltype(A), width} + BatchDuplicated{eltype(A),width} end elseif A <: DuplicatedNoNeed && width != 1 if A isa UnionAll - BatchDuplicatedNoNeed{T, width} where T + BatchDuplicatedNoNeed{T,width} where {T} else - BatchDuplicatedNoNeed{eltype(A), width} + BatchDuplicatedNoNeed{eltype(A),width} end else A end - - ModifiedBetween = Val(falses_from_args(Nargs+1)) - - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) @@ -425,8 +622,20 @@ f(x) = x*x Val(codegen_world_age(Core.Typeof(f.val), tt)) end - thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + thunk = Enzyme.Compiler.thunk( + opt_mi, + FA, + RT, + tt′, + Val(API.DEM_ForwardMode), + Val(width), #=Mode=# + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# thunk(f, args...) end @@ -436,16 +645,30 @@ end Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. 
""" -@inline function autodiff_deferred(::ReverseMode{ReturnPrimal, RuntimeActivity, ABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs, ABI,Holomorphic,ErrIfFuncWritten, RuntimeActivity} - tt′ = vaTypeof(args...) +@inline function autodiff_deferred( + ::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + Nargs, + ABI, + Holomorphic, + ErrIfFuncWritten, + RuntimeActivity, +} + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + world = codegen_world_age(Core.Typeof(f.val), tt) - + if A isa UnionAll rt = Core.Compiler.return_type(f.val, tt) rt = A{rt} @@ -458,14 +681,31 @@ code, as well as high-order differentiation. error("Return type inferred to be Union{}. 
Giving up.") end - ModifiedBetween = Val(falses_from_args(Nargs+1)) - - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - - thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + adjoint_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(tt′), + Val(rt), + Val(API.DEM_ReverseModeCombined), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + UnknownTapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + + thunk = + Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,rt,tt′,width,ReturnPrimal}(adjoint_ptr) if rt <: Active args = (args..., Compiler.default_adjoint(eltype(rt))) - elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed + elseif A <: Duplicated || + A <: DuplicatedNoNeed || + A <: BatchDuplicated || + A <: BatchDuplicatedNoNeed throw(ErrorException("Duplicated Returns not yet handled")) end thunk(f, args...) @@ -477,37 +717,54 @@ end Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_deferred( + ::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + ReturnPrimal, + FA<:Annotation, + A<:Annotation, + Nargs, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} if any_active(args...) 
throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = vaTypeof(args...) + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed - throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + throw( + ErrorException( + "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)", + ), + ) end RT = if A <: Duplicated && width != 1 if A isa UnionAll - BatchDuplicated{T, width} where T + BatchDuplicated{T,width} where {T} else - BatchDuplicated{eltype(A), width} + BatchDuplicated{eltype(A),width} end elseif A <: DuplicatedNoNeed && width != 1 if A isa UnionAll - BatchDuplicatedNoNeed{T, width} where T + BatchDuplicatedNoNeed{T,width} where {T} else - BatchDuplicatedNoNeed{eltype(A), width} + BatchDuplicatedNoNeed{eltype(A),width} end else A end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + world = codegen_world_age(Core.Typeof(f.val), tt) - + if RT isa UnionAll rt = Core.Compiler.return_type(f.val, tt) rt = RT{rt} @@ -524,10 +781,23 @@ code, as well as high-order differentiation. 
throw(ErrorException("Active Returns not allowed in forward mode")) end - ModifiedBetween = Val(falses_from_args(Nargs+1)) - - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + adjoint_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(tt′), + Val(rt), + Val(API.DEM_ForwardMode), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + UnknownTapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + thunk = Compiler.ForwardModeThunk{Ptr{Cvoid},FA,rt,tt′,width,ReturnPrimal}(adjoint_ptr) thunk(f, args...) end @@ -574,7 +844,31 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_thunk(rs::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT,RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_thunk( + rs::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + ErrIfFuncWritten, + }, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -586,13 +880,13 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) else ModifiedBetween = Val(ModifiedBetweenT) end - tt = Tuple{map(eltype, args)...} - + tt = Tuple{map(eltype, args)...} + if !(A <: Const) @assert ReturnShadow end @@ -602,7 +896,20 @@ result, ∂v, ∂A else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# end """ @@ -620,11 +927,20 @@ end ((6.2,),) ``` """ -@inline function autodiff(f::Function, m::MMode, ::Type{A}, args::Vararg{Annotation, Nargs}) where {A<:Annotation, Nargs, MMode<:Mode} - autodiff(m, f, A, args...) +@inline function autodiff( + f::Function, + m::MMode, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where {A<:Annotation,Nargs,MMode<:Mode} + autodiff(m, f, A, args...) end -@inline function autodiff(f::Function, m::MMode, args::Vararg{Annotation, Nargs}) where {Nargs, MMode<:Mode} - autodiff(m, f, args...) +@inline function autodiff( + f::Function, + m::MMode, + args::Vararg{Annotation,Nargs}, +) where {Nargs,MMode<:Mode} + autodiff(m, f, args...) 
end """ @@ -671,7 +987,20 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_thunk( + ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + ReturnPrimal, + FA<:Annotation, + A<:Annotation, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} width = same_or_one(1, A, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -680,23 +1009,64 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float throw(ErrorException("Active Returns not allowed in forward mode")) end if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed - throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + throw( + ErrorException( + "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)", + ), + ) end - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + tt = Tuple{map(eltype, args)...} - tt = Tuple{map(eltype, args)...} - tt′ = Tuple{args...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) else Val(codegen_world_age(eltype(FA), tt)) end - results = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + results = Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + tt′, + Val(API.DEM_ForwardMode), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# end -@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function tape_type( + ::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + ErrIfFuncWritten, + }, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -708,21 +1078,34 @@ end end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) else ModifiedBetween = Val(ModifiedBetweenT) end @assert ReturnShadow TT = Tuple{args...} - + primal_tt = Tuple{map(eltype, args)...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), TT) else Val(codegen_world_age(eltype(FA), primal_tt)) end - nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + nondef = Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + TT, + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# if nondef[1] isa Enzyme.Compiler.PrimalErrorThunk return Nothing else @@ -731,16 +1114,36 @@ end end end -const tape_cache = Dict{UInt, Type}() +const tape_cache = Dict{UInt,Type}() const tape_cache_lock = ReentrantLock() import .Compiler: fspec, remove_innerty, UnknownTapeType @inline function tape_type( - parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI}, - ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs} -) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, RuntimeActivity} + parent_job::Union{GPUCompiler.CompilerJob,Nothing}, + ::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + }, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + RuntimeActivity, +} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -768,12 +1171,21 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( - Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, - Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, - ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI, #=errifwritte=#false, - RuntimeActivity + Tuple{FA,TT.parameters...}, + API.DEM_ReverseModeGradient, + width, + Compiler.remove_innerty(A), + true, + false, + ModifiedBetweenT, #=abiwrap=# + ReturnPrimal, + false, + Compiler.UnknownTapeType, + RABI, + false, #=errifwritte=# + RuntimeActivity, ) - job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false)) + job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel = false)) key = hash(parent_job, hash(job)) @@ -786,7 +1198,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType if obj === nothing Compiler.JuliaContext() do ctx - _, meta = Compiler.codegen(:llvm, job; optimize=false, parent_job) + _, meta = Compiler.codegen(:llvm, job; optimize = false, parent_job) obj = meta.TapeType tape_cache[key] = obj end @@ -841,7 +1253,33 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_deferred_thunk( + mode::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + ErrIfFuncWritten, + }, + tt::Type{TapeType}, + fa::Type{FA}, + a2::Type{A2}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A2<:Annotation, + TapeType, + 
ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(1, args...) @@ -854,7 +1292,7 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) else ModifiedBetween = Val(ModifiedBetweenT) end @@ -865,40 +1303,69 @@ result, ∂v, ∂A primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + primal_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(TT), + Val(Compiler.remove_innerty(A2)), + Val(API.DEM_ReverseModePrimal), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + TapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + adjoint_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(TT), + Val(Compiler.remove_innerty(A2)), + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + TapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# RT = if A2 <: Duplicated && width != 1 if A2 isa UnionAll - BatchDuplicated{T, width} where T + BatchDuplicated{T,width} where {T} else - BatchDuplicated{eltype(A2), width} + BatchDuplicated{eltype(A2),width} end elseif A2 <: DuplicatedNoNeed && width != 1 if A2 isa UnionAll - BatchDuplicatedNoNeed{T, width} where T + 
BatchDuplicatedNoNeed{T,width} where {T} else - BatchDuplicatedNoNeed{eltype(A2), width} + BatchDuplicatedNoNeed{eltype(A2),width} end elseif A2 <: MixedDuplicated && width != 1 if A2 isa UnionAll - BatchMixedDuplicated{T, width} where T + BatchMixedDuplicated{T,width} where {T} else - BatchMixedDuplicated{eltype(A2), width} + BatchMixedDuplicated{eltype(A2),width} end else A2 end - + rt = if RT isa UnionAll - RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})} + RT{Core.Compiler.return_type(Tuple{eltype(FA),map(eltype, args)...})} else @assert RT isa DataType RT end - aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, rt, TT, width, ReturnPrimal, TapeType}(primal_ptr) - adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, rt, TT, width, TapeType}(adjoint_ptr) + aug_thunk = + Compiler.AugmentedForwardThunk{Ptr{Cvoid},FA,rt,TT,width,ReturnPrimal,TapeType}( + primal_ptr, + ) + adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid},FA,rt,TT,width,TapeType}(adjoint_ptr) aug_thunk, adj_thunk end @@ -911,11 +1378,11 @@ Base.@ccallable function __enzyme_double(x::Ptr{Cvoid})::Cvoid return nothing end -@inline function markType(::Type{T}, ptr::Ptr{Cvoid}) where T +@inline function markType(::Type{T}, ptr::Ptr{Cvoid}) where {T} markType(Base.unsafe_convert(Ptr{T}, ptr)) end -@inline function markType(data::Array{T}) where T +@inline function markType(data::Array{T}) where {T} GC.@preserve data markType(pointer(data)) end @@ -925,20 +1392,52 @@ end end @inline function markType(data::Ptr{Float32}) -@static if sizeof(Int) == sizeof(Int64) - Base.llvmcall(("declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_float(i8* %p) ret void }", "c"), Cvoid, Tuple{Ptr{Float32}}, data) -else - Base.llvmcall(("declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_float(i8* %p) ret void }", 
"c"), Cvoid, Tuple{Ptr{Float32}}, data) -end + @static if sizeof(Int) == sizeof(Int64) + Base.llvmcall( + ( + "declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_float(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float32}}, + data, + ) + else + Base.llvmcall( + ( + "declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_float(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float32}}, + data, + ) + end nothing end @inline function markType(data::Ptr{Float64}) -@static if sizeof(Int) == sizeof(Int64) - Base.llvmcall(("declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_double(i8* %p) ret void }", "c"), Cvoid, Tuple{Ptr{Float64}}, data) -else - Base.llvmcall(("declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_double(i8* %p) ret void }", "c"), Cvoid, Tuple{Ptr{Float64}}, data) -end + @static if sizeof(Int) == sizeof(Int64) + Base.llvmcall( + ( + "declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_double(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float64}}, + data, + ) + else + Base.llvmcall( + ( + "declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_double(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float64}}, + data, + ) + end nothing end @@ -947,24 +1446,24 @@ end ntuple(Val(N)) do i Base.@_inline_meta res = similar(x) - for idx in 1:N + for idx = 1:N @inbounds res[idx] = (i == idx) ? 
1.0 : 0.0 end return res end end @inline function onehot(x, start, endl) - ntuple(Val(endl-start+1)) do i + ntuple(Val(endl - start + 1)) do i Base.@_inline_meta res = similar(x) - for idx in 1:length(x) - @inbounds res[idx] = (i + start - 1== idx) ? 1.0 : 0.0 + for idx = 1:length(x) + @inbounds res[idx] = (i + start - 1 == idx) ? 1.0 : 0.0 end return res end end -@inline function onehot(::Type{NTuple{N, T}}) where {T, N} +@inline function onehot(::Type{NTuple{N,T}}) where {T,N} ntuple(Val(N)) do i Base.@_inline_meta ntuple(Val(N)) do idx @@ -973,11 +1472,11 @@ end end end end -@inline function onehot(x::NTuple{N, T}) where {T, N} - onehot(NTuple{N, T}) +@inline function onehot(x::NTuple{N,T}) where {T,N} + onehot(NTuple{N,T}) end -@inline function onehot(x::NTuple{N, T}, start, endl) where {T, N} - ntuple(Val(endl-start+1)) do i +@inline function onehot(x::NTuple{N,T}, start, endl) where {T,N} + ntuple(Val(endl - start + 1)) do i Base.@_inline_meta ntuple(Val(N)) do idx Base.@_inline_meta @@ -1067,21 +1566,41 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) ``` """ -@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} - toemit= Expr[quote - act_0 = !(x isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof(x), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState +@generated function gradient( + rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + f::F, + x::ty_0, + args::Vararg{Any,N}, +) where {F,ty_0,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten,N} + toemit = Expr[quote + act_0 = + !(x isa Enzyme.Const) && + Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == + Compiler.ActiveState #=justActive=# end] rargs = Union{Symbol,Expr}[:x] acts = Symbol[Symbol("act_0")] - for i in 1:N - 
argidx = quote args[$i] end + for i = 1:N + argidx = quote + args[$i] + end push!(rargs, argidx) sym = Symbol("act_$i") push!(acts, sym) - push!(toemit, quote - $sym = !($argidx isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof($argidx), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - end) + push!( + toemit, + quote + $sym = + !($argidx isa Enzyme.Const) && + Compiler.active_reg_inner( + Core.Typeof($argidx), + (), + nothing, + Val(true), + ) == Compiler.ActiveState #=justActive=# + end, + ) end idx = 0 @@ -1118,7 +1637,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) $shad end end) - idx+=1 + idx += 1 end push!(toemit, quote res = autodiff(rm, f, Active, $(enz_args...)) @@ -1128,7 +1647,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) return quote Base.@_inline_meta $(toemit...) - (; derivs=($(resargs...),), val=res[2]) + (; derivs = ($(resargs...),), val = res[2]) end else return quote @@ -1166,26 +1685,31 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) (derivs = ([3.0, 2.0],), val = 6.0) ``` """ -@inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} +@inline function gradient!( + rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + dx::X, + f::F, + x::X, +) where {X<:Array,F,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (; derivs=(dx,), val=res[2]) + (; derivs = (dx,), val = res[2]) else (dx,) end end -@inline function chunkedonehot(x, ::Val{chunk}) where chunk +@inline function chunkedonehot(x, ::Val{chunk}) where {chunk} sz = length(x) num = ((sz + chunk - 1) ÷ chunk) ntuple(Val(num)) do i Base.@_inline_meta - onehot(x, (i-1)*chunk+1, i == num ? 
sz : (i*chunk) ) + onehot(x, (i - 1) * chunk + 1, i == num ? sz : (i * chunk)) end end -@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk +@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where {chunk} return ((one(x),),) end @@ -1201,23 +1725,27 @@ function create_shadows(::Val{1}, x) return (onehot(x),) end -function create_shadows(::Val{chunk}, x) where chunk +function create_shadows(::Val{chunk}, x) where {chunk} return (chunkedonehot(x, Val(chunk)),) end -struct TupleArray{T, Shape, Length, N} <: AbstractArray{T,N} - data::NTuple{Length, T} +struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N} + data::NTuple{Length,T} end -TupleArray(data::NTuple{Length, T}, Shape) where {Length, T} = TupleArray{T, Shape, Length, length(Shape)}(data) - -@inline Base.eltype(::TupleArray{T}) where T = T -@inline Base.eltype(::Type{<:TupleArray{T}}) where T = T -@inline Base.size(::TupleArray{<:Any, Shape}) where Shape = Shape -@inline Base.ndims(::TupleArray{<:Any, <:Any, <:Any, N}) where N = N - -function Base.convert(::Type{Array{T, N}}, X::TupleArray{T, Shape, Length, N}) where {T, Shape, Length, N} - vals = Array{T, N}(undef, Shape...) - for i in 1:Length +TupleArray(data::NTuple{Length,T}, Shape) where {Length,T} = + TupleArray{T,Shape,Length,length(Shape)}(data) + +@inline Base.eltype(::TupleArray{T}) where {T} = T +@inline Base.eltype(::Type{<:TupleArray{T}}) where {T} = T +@inline Base.size(::TupleArray{<:Any,Shape}) where {Shape} = Shape +@inline Base.ndims(::TupleArray{<:Any,<:Any,<:Any,N}) where {N} = N + +function Base.convert( + ::Type{Array{T,N}}, + X::TupleArray{T,Shape,Length,N}, +) where {T,Shape,Length,N} + vals = Array{T,N}(undef, Shape...) 
+ for i = 1:Length @inbounds val[i] = X.data[i] end return vals @@ -1225,9 +1753,9 @@ end function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} start = 0 - for i in 1:N + for i = 1:N start *= size(a, N - i + 1) - start += (args[N - i + 1] - 1) + start += (args[N-i+1] - 1) end start += 1 return a.data[start] @@ -1301,10 +1829,16 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) ([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` """ -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS} +@inline function gradient( + fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + f, + x; + chunk::CS = nothing, + shadows = create_shadows(chunk, x), +) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS} if length(shadows[1]) == 0 return if ReturnPrimal - (; derivs=(x,), val=f(x.val)) + (; derivs = (x,), val = f(x.val)) else (x,) end @@ -1331,9 +1865,9 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) if ReturnPrimal rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1])) dres1 = rp[1] - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# - res = ntuple(length(shadows[1])-1) do i + res = ntuple(length(shadows[1]) - 1) do i autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1] end gres = if x isa AbstractFloat @@ -1359,9 +1893,16 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) gres = if x isa AbstractFloat dres1[1] else - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() - tmp = ntuple(length(shadows[1])-1) do i - values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i+1]))[1]) + fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# + tmp = ntuple(length(shadows[1]) - 1) do i + values( + autodiff( + 
fm2, + f, + BatchDuplicated, + BatchDuplicated(x, shadows[1][i+1]), + )[1], + ) end tupleconcat(dres1, tmp...) end @@ -1397,7 +1938,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) cols end if ReturnPrimal - (; derivs=(res,), val=gradtup[2]) + (; derivs = (res,), val = gradtup[2]) else (res,) end @@ -1466,7 +2007,13 @@ In the future, when this function is extended to handle non-array return types, this function will retun an AbstractArray of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X; n_outs::OutType=nothing, chunk::CT=nothing) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, OutType, CT, Holomorphic} +@inline function jacobian( + ::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + f::F, + x::X; + n_outs::OutType = nothing, + chunk::CT = nothing, +) where {ReturnPrimal,F,X,RABI<:ABI,ErrIfFuncWritten,RuntimeActivity,OutType,CT,Holomorphic} if n_outs == nothing res = if f isa Const @@ -1475,43 +2022,57 @@ this function will retun an AbstractArray of shape `size(output)` of values of t f(x) end jac = if res isa AbstractArray - jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x; n_outs=Val(size(res)), chunk) + jacobian( + ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), + f, + x; + n_outs = Val(size(res)), + chunk, + ) elseif res isa AbstractFloat - gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) + gradient( + ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), + f, + x, + ) else - throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) + throw( + AssertionError( + "Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))", + ), + ) end return if ReturnPrimal - (; derivs=jac, val=res) + 
(; derivs = jac, val = res) else jac end else - @assert !Holomorphic + @assert !Holomorphic n_out_val = if length(Compiler.element(n_outs)) == 0 0 else prod(Compiler.element(n_outs)) end - + if chunk == Val(0) throw(ErrorException("Cannot differentiate with a batch size of 0")) end - - XT = Core.Typeof(x) - MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - tt = Tuple{XT} + + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, (), nothing, Val(true)) == Compiler.ActiveState #=justActive=# + tt = Tuple{XT} rt = if f isa Const Core.Compiler.return_type(f.val, tt) else Core.Compiler.return_type(f, tt) end - + ModifiedBetween = Val((false, false)) FRT = Core.Typeof(f) FA = Const{FRT} - + opt_mi = if RABI <: NonGenABI Compiler.fspec(FRT, tt′) else @@ -1519,8 +2080,21 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end if chunk == Val(1) || chunk == nothing - tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} + primal, adjoint = Enzyme.Compiler.thunk( + opt_mi, + FA, + DuplicatedNoNeed{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(1), + ModifiedBetween, + Val(false), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# tmp = ntuple(Val(n_out_val)) do i Base.@_inline_meta z = make_zero(x) @@ -1536,18 +2110,46 @@ this function will retun an AbstractArray of shape `size(output)` of values of t rows, outshape else chunksize = Compiler.element(chunk) - tt′ = MD ? 
Tuple{BatchMixedDuplicated{XT, chunksize}} : Tuple{BatchDuplicated{XT, chunksize}} - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#chunk, ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - + tt′ = + MD ? Tuple{BatchMixedDuplicated{XT,chunksize}} : + Tuple{BatchDuplicated{XT,chunksize}} + primal, adjoint = Enzyme.Compiler.thunk( + opt_mi, + FA, + BatchDuplicatedNoNeed{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + chunk, + ModifiedBetween, + Val(false), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + num = ((n_out_val + chunksize - 1) ÷ chunksize) - + if num * chunksize == n_out_val last_size = chunksize primal2, adjoint2 = primal, adjoint else - last_size = n_out_val - (num-1)*chunksize - tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + last_size = n_out_val - (num - 1) * chunksize + tt′ = Tuple{BatchDuplicated{Core.Typeof(x),last_size}} + primal2, adjoint2 = Enzyme.Compiler.thunk( + opt_mi, + FA, + BatchDuplicatedNoNeed{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(last_size), + ModifiedBetween, + Val(false), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# end tmp = ntuple(num) do i @@ -1557,18 +2159,29 @@ this function will retun an AbstractArray of shape `size(output)` of values of t z = make_zero(x) MD ? Ref(z) : z end - res = (i == num ? primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) + res = (i == num ? primal2 : primal)( + Const(f), + MD ? 
BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), + ) tape = res[1] j = 0 for shadow in res[3] j += 1 - @inbounds shadow[(i-1)*chunksize+j] += Compiler.default_adjoint(eltype(typeof(shadow))) + @inbounds shadow[(i-1)*chunksize+j] += + Compiler.default_adjoint(eltype(typeof(shadow))) end - (i == num ? adjoint2 : adjoint)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) - return MD ? (ntuple(Val(i == num ? last_size : chunksize)) do idx - Base.@_inline_meta - dx[idx][] - end) : dx, (i == 1 ? size(res[3][1]) : nothing) + (i == num ? adjoint2 : adjoint)( + Const(f), + MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), + tape, + ) + return MD ? ( + ntuple(Val(i == num ? last_size : chunksize)) do idx + Base.@_inline_meta + dx[idx][] + end + ) : dx, + (i == 1 ? size(res[3][1]) : nothing) end rows = tupleconcat(map(first, tmp)...) outshape = tmp[1][2] @@ -1581,7 +2194,10 @@ this function will retun an AbstractArray of shape `size(output)` of values of t st3 = if length(outshape) == 1 && length(inshape) == 1 transpose(st2) else - transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... 
) + transp = ( + ((length(inshape)+1):(length(inshape)+length(outshape)))..., + (1:length(inshape))..., + ) PermutedDimsArray(st2, transp) end @@ -1590,14 +2206,14 @@ this function will retun an AbstractArray of shape `size(output)` of values of t reshape(collect(rows), outshape) end if ReturnPrimal - # TODO optimize away redundant fwd pass - (; derivs=(res,), val=if f isa Enzyme.Const - f.val(x) - else - f(x) - end) + # TODO optimize away redundant fwd pass + (; derivs = (res,), val = if f isa Enzyme.Const + f.val(x) + else + f(x) + end) else - (res,) + (res,) end end end @@ -1624,7 +2240,7 @@ hvp(f, [2.0, 3.0], [5.0, 2.7]) 16.201003759768003 ``` """ -@inline function hvp(f::F, x::X, v::X) where {F, X} +@inline function hvp(f::F, x::X, v::X) where {F,X} res = make_zero(x) hvp!(res, f, x, v) return res @@ -1657,9 +2273,16 @@ res 16.201003759768003 ``` """ -@inline function hvp!(res::X, f::F, x::X, v::X) where {F, X} +@inline function hvp!(res::X, f::F, x::X, v::X) where {F,X} grad = make_zero(x) - Enzyme.autodiff(Forward, gradient!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff( + Forward, + gradient!, + Const(Reverse), + DuplicatedNoNeed(grad, res), + Const(f), + Duplicated(x, v), + ) return nothing end @@ -1693,8 +2316,15 @@ grad 1.920340573300732 ``` """ -@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} - Enzyme.autodiff(Forward, gradient!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) +@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F,X} + Enzyme.autodiff( + Forward, + gradient!, + Const(Reverse), + Duplicated(grad, res), + Const(f), + Duplicated(x, v), + ) return nothing end @@ -1732,7 +2362,7 @@ Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,))) """ macro import_frule(args...) return _import_frule(args...) 
-end +end function _import_rrule end # defined in EnzymeChainRulesCoreExt extension diff --git a/src/absint.jl b/src/absint.jl index b84657aadb..585b1625a3 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -1,9 +1,8 @@ # Abstractly interpret julia from LLVM # Return (bool if could interpret, julia object interpreted to) -function absint(arg::LLVM.Value, partial::Bool=false) - if isa(arg, LLVM.BitCastInst) || - isa(arg, LLVM.AddrSpaceCastInst) +function absint(arg::LLVM.Value, partial::Bool = false) + if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end if isa(arg, ConstantExpr) @@ -18,12 +17,17 @@ function absint(arg::LLVM.Value, partial::Bool=false) nm = LLVM.name(fn) end for (fname, ty) in ( - ("jl_box_int64", Int64), ("ijl_box_int64", Int64), - ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), - ("jl_box_int32", Int32), ("ijl_box_int32", Int32), - ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), - ("jl_box_char", Char), ("ijl_box_char", Char), - ) + ("jl_box_int64", Int64), + ("ijl_box_int64", Int64), + ("jl_box_uint64", UInt64), + ("ijl_box_uint64", UInt64), + ("jl_box_int32", Int32), + ("ijl_box_int32", Int32), + ("jl_box_uint32", UInt32), + ("ijl_box_uint32", UInt32), + ("jl_box_char", Char), + ("ijl_box_char", Char), + ) if nm == fname v = first(operands(arg)) if isa(v, ConstantInt) @@ -39,7 +43,8 @@ function absint(arg::LLVM.Value, partial::Bool=false) return absint(operands(arg)[1], partial) end if nm == "jl_typeof" || nm == "ijl_typeof" - return abs_typeof(operands(arg)[1], partial) + vals = abs_typeof(operands(arg)[1], partial) + return (vals[1], vals[2]) end if LLVM.callconv(arg) == 37 || nm == "julia.call" index = 1 @@ -54,11 +59,11 @@ function absint(arg::LLVM.Value, partial::Bool=false) legal, Ty = absint(operands(arg)[index], partial) unionalls = [] for sarg in operands(arg)[index+1:end-1] - slegal , foundv = absint(sarg, partial) + slegal, foundv = absint(sarg, partial) if slegal 
push!(found, foundv) elseif partial - foundv = TypeVar(Symbol("sarg"*string(sarg))) + foundv = TypeVar(Symbol("sarg" * string(sarg))) push!(found, foundv) push!(unionalls, foundv) else @@ -80,7 +85,7 @@ function absint(arg::LLVM.Value, partial::Bool=false) found = [] legal = true for sarg in operands(arg)[index:end-1] - slegal , foundv = absint(sarg, partial) + slegal, foundv = absint(sarg, partial) if slegal push!(found, foundv) else @@ -107,25 +112,28 @@ function absint(arg::LLVM.Value, partial::Bool=false) end end - if isa(arg, GlobalVariable) + if isa(arg, GlobalVariable) gname = LLVM.name(arg) for (k, v) in JuliaGlobalNameMap - if gname == k || gname == "ejl_"*k + if gname == k || gname == "ejl_" * k return (true, v) end end for (k, v) in JuliaEnzymeNameMap - if gname == k || gname == "ejl_"*k + if gname == k || gname == "ejl_" * k return (true, v) end end end - if isa(arg, LLVM.LoadInst) && value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) + if isa(arg, LLVM.LoadInst) && + value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) ptr = operands(arg)[1] ce = ptr while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr + if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || + opcode(ce) == LLVM.API.LLVMBitCast || + opcode(ce) == LLVM.API.LLVMIntToPtr ce = operands(ce)[1] else break @@ -149,9 +157,21 @@ function absint(arg::LLVM.Value, partial::Bool=false) return (false, nothing) end -function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Type},Tuple{Bool, Nothing}} - if isa(arg, LLVM.BitCastInst) || - isa(arg, LLVM.AddrSpaceCastInst) +function actual_size(@nospecialize(typ2)) + if typ2 <: Array || typ2 <: AbstractString + return sizeof(Int) + elseif Base.isconcretetype(typ2) + return sizeof(typ2) + else + return sizeof(Int) + end +end + +function abs_typeof( + arg::LLVM.Value, + partial::Bool = false, 
+)::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}} + if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return abs_typeof(operands(arg)[1], partial) end if isa(arg, ConstantExpr) @@ -160,7 +180,7 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end end - if isa(arg, LLVM.CallInst) + if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" if isa(fn, LLVM.Function) @@ -170,27 +190,36 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ if nm == "julia.pointer_from_objref" return abs_typeof(operands(arg)[1], partial) end - + for (fname, ty) in ( - ("jl_box_int64", Int64), ("ijl_box_int64", Int64), - ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), - ("jl_box_int32", Int32), ("ijl_box_int32", Int32), - ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), - ("jl_box_float32", Float32), ("ijl_box_float32", Float32), - ("jl_box_char", Char), ("ijl_box_char", Char), - ("jl_specializations_get_linfo", Core.MethodInstance), - ("ijl_specializations_get_linfo", Core.MethodInstance), - ) + ("jl_box_int64", Int64), + ("ijl_box_int64", Int64), + ("jl_box_uint64", UInt64), + ("ijl_box_uint64", UInt64), + ("jl_box_int32", Int32), + ("ijl_box_int32", Int32), + ("jl_box_uint32", UInt32), + ("ijl_box_uint32", UInt32), + ("jl_box_float32", Float32), + ("ijl_box_float32", Float32), + ("jl_box_char", Char), + ("ijl_box_char", Char), + ("jl_specializations_get_linfo", Core.MethodInstance), + ("ijl_specializations_get_linfo", Core.MethodInstance), + ) if nm == fname - return (true, ty) + return (true, ty, GPUCompiler.MUT_REF) end end - - # Type tag is arg 3 - if nm == "julia.gc_alloc_obj" || nm == "jl_gc_alloc_typed" || nm == "ijl_gc_alloc_typed" - return absint(operands(arg)[3], partial) + + # Type tag is arg 3 + if nm == "julia.gc_alloc_obj" || + nm == "jl_gc_alloc_typed" || + nm == "ijl_gc_alloc_typed" + vals = absint(operands(arg)[3], partial) + return (vals[1], 
vals[2], vals[1] ? GPUCompiler.BITS_REF : nothing) end - # Type tag is arg 1 + # Type tag is arg 1 if nm == "jl_alloc_array_1d" || nm == "ijl_alloc_array_1d" || nm == "jl_alloc_array_2d" || @@ -199,11 +228,13 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ nm == "ijl_alloc_array_3d" || nm == "jl_new_array" || nm == "ijl_new_array" - return absint(operands(arg)[1], partial) + vals = absint(operands(arg)[1], partial) + return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing) end if nm == "jl_new_structt" || nm == "ijl_new_structt" - return absint(operands(arg)[1], partial) + vals = absint(operands(arg)[1], partial) + return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing) end if LLVM.callconv(arg) == 37 || nm == "julia.call" @@ -213,14 +244,15 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ nm = LLVM.name(fn) index += 1 end - - if nm == "jl_f_isdefined" || nm == "ijl_f_isdefined" - return (true, Bool) - end + + if nm == "jl_f_isdefined" || nm == "ijl_f_isdefined" + return (true, Bool, GPUCompiler.MUT_REF) + end if nm == "jl_new_structv" || nm == "ijl_new_structv" @assert index == 2 - return absint(operands(arg)[index], partial) + vals = absint(operands(arg)[index], partial) + return (vals[1], vals[2], vals[1] ? 
GPUCompiler.MUT_REF : nothing) end if nm == "jl_f_tuple" || nm == "ijl_f_tuple" @@ -229,11 +261,11 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ unionalls = [] legal = true for sarg in operands(arg)[index:end-1] - slegal , foundv = abs_typeof(sarg, partial) + slegal, foundv, _ = abs_typeof(sarg, partial) if slegal push!(found, foundv) elseif partial - foundv = TypeVar(Symbol("sarg"*string(sarg))) + foundv = TypeVar(Symbol("sarg" * string(sarg))) push!(found, foundv) push!(unionalls, foundv) else @@ -246,7 +278,7 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ for u in unionalls res = UnionAll(u, res) end - return (true, res) + return (true, res, GPUCompiler.BITS_REF) end end end @@ -261,16 +293,26 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end if nm == "jl_array_copy" || nm == "ijl_array_copy" - legal, RT = abs_typeof(operands(arg)[1], partial) + legal, RT, _ = abs_typeof(operands(arg)[1], partial) if legal @assert RT <: Array + return (legal, RT, GPUCompiler.MUT_REF) end - return (legal, RT) + return (legal, RT, nothing) end _, RT = enzyme_custom_extract_mi(arg, false) if RT !== nothing - return (true, RT) + llrt, sret, returnRoots = get_return_info(RT) + if sret !== nothing + if llrt == RT + return (true, RT, GPUCompiler.BITS_VALUE) + elseif llrt == Ptr{RT} + return (true, RT, GPUCompiler.MUT_REF) + elseif llrt == Any + return (true, RT, GPUCompiler.BITS_REF) + end + end end end @@ -279,15 +321,16 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ offset = nothing error = false while true - if isa(larg, LLVM.BitCastInst) || - isa(larg, LLVM.AddrSpaceCastInst) - larg = operands(larg)[1] - continue + if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) + larg = operands(larg)[1] + continue end - if offset === nothing && isa(larg, LLVM.GetElementPtrInst) && all(x->isa(x, LLVM.ConstantInt), 
operands(larg)[2:end]) - b = LLVM.IRBuilder() + if offset === nothing && + isa(larg, LLVM.GetElementPtrInst) && + all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) + b = LLVM.IRBuilder() position!(b, larg) - offty = LLVM.IntType(8*sizeof(Int)) + offty = LLVM.IntType(8 * sizeof(Int)) offset = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) @assert isa(offset, LLVM.ConstantInt) offset = convert(Int, offset) @@ -302,154 +345,148 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end if !error - if isa(larg, LLVM.Argument) - f = LLVM.Function(LLVM.API.LLVMGetParamParent(larg)) - idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == larg]) - typ, byref = enzyme_extract_parm_type(f, idx, #=error=#false) + legal, typ, byref = abs_typeof(larg) + if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) @static if VERSION < v"1.11-" - if typ !== nothing && typ <: Array && Base.isconcretetype(typ) + if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) if offset === nothing || offset == 0 - return (true, Ptr{T}) + return (true, Ptr{T}, GPUCompiler.BITS_VALUE) else - return (true, Int) + return (true, Int, GPUCompiler.BITS_VALUE) end end end - if typ !== nothing && byref == GPUCompiler.BITS_REF + if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) if offset === nothing - return (true, typ) - else - function llsz(ty) - if isa(ty, LLVM.PointerType) - return sizeof(Ptr{Cvoid}) - elseif isa(ty, LLVM.IntegerType) - return LLVM.width(ty) / 8 + byref = GPUCompiler.BITS_VALUE + legal = true + typ2 = typ + while actual_size(typ2) != sizeof(dl, value_type(arg)) + if fieldcount(typ2) > 0 + typ2 = fieldtype(typ, 1) + if !Base.allocatedinline(typ2) + if byref != GPUCompiler.BITS_VALUE + legal = false + break + end + byref = GPUCompiler.MUT_REF + continue + end end - error("Unknown llvm type to size: "*string(ty)) + legal = false + 
break + end + if legal + return (true, typ2, byref) end + else @assert Base.isconcretetype(typ) - for i in 1:fieldcount(typ) + for i = 1:fieldcount(typ) if fieldoffset(typ, i) == offset - subT = fieldtype(typ, i) + subT = fieldtype(typ, i) fsize = if i == fieldcount(typ) sizeof(typ) else - fieldoffset(typ, i+1) + fieldoffset(typ, i + 1) end - offset - if fsize == llsz(value_type(larg)) - if Base.isconcretetype(subT) && is_concrete_tuple(subT) && length(subT.parameters) == 1 + if fsize == sizeof(dl, value_type(arg)) + if Base.isconcretetype(subT) && + is_concrete_tuple(subT) && + length(subT.parameters) == 1 subT = subT.parameters[1] end - return (true, subT) + if Base.allocatedinline(subT) + return (true, subT, GPUCompiler.BITS_VALUE) + else + return (true, subT, GPUCompiler.MUT_REF) + end end end end end end - else - legal, RT = abs_typeof(larg) - if legal - if RT <: Array && Base.isconcretetype(RT) - @static if VERSION < v"1.11-" - T = eltype(RT) - - if offset == 0 - return (true, Ptr{T}) - end - - return (true, Int) - end - end - if RT <: Ptr && Base.isconcretetype(RT) - return (true, eltype(RT)) - end - end + elseif legal && if typ <: Ptr && Base.isconcretetype(typ) + return (true, eltype(typ), GPUCompiler.BITS_VALUE) + end end end end - + if isa(arg, LLVM.ExtractValueInst) larg = operands(arg)[1] indptrs = LLVM.API.LLVMGetIndices(arg) numind = LLVM.API.LLVMGetNumIndices(arg) - offset = Cuint[unsafe_load(indptrs, i) for i in 1:numind] - if isa(larg, LLVM.Argument) || isa(larg, LLVM.ExtractValueInst) - typ, byref = if isa(larg, LLVM.Argument) - f = LLVM.Function(LLVM.API.LLVMGetParamParent(larg)) - idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == larg]) - enzyme_extract_parm_type(f, idx, #=error=#false) - else - found, typ = abs_typeof(larg, partial) - if !found - return (false, nothing) - end - (typ, GPUCompiler.BITS_VALUE) - end - if typ !== nothing && byref == GPUCompiler.BITS_VALUE - for ind in offset - @assert Base.isconcretetype(typ) - cnt 
= 0 - for i in 1:fieldcount(typ) - styp = fieldtype(typ, i) - if isghostty(styp) - continue - end - if cnt == ind - typ = styp - break - end - cnt+=1 + offset = Cuint[unsafe_load(indptrs, i) for i = 1:numind] + found, typ, byref = abs_typeof(larg, partial) + if !found + return (false, nothing, nothing) + end + if byref == GPUCompiler.BITS_VALUE + for ind in offset + @assert Base.isconcretetype(typ) + cnt = 0 + for i = 1:fieldcount(typ) + styp = fieldtype(typ, i) + if isghostty(styp) + continue end + if cnt == ind + typ = styp + break + end + cnt += 1 end - return (true, typ) + end + if Base.allocatedinline(typ) + return (true, typ, GPUCompiler.BITS_VALUE) + else + return (true, typ, GPUCompiler.MUT_REF) end end end - + if isa(arg, LLVM.Argument) f = LLVM.Function(LLVM.API.LLVMGetParamParent(arg)) idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == arg]) - typ, byref = enzyme_extract_parm_type(f, idx, #=error=#false) + typ, byref = enzyme_extract_parm_type(f, idx, false) #=error=# if typ !== nothing - if byref == GPUCompiler.BITS_REF - typ = Ptr{typ} - end - return (true, typ) + return (true, typ, byref) end end legal, val = absint(arg, partial) - if legal - return (true, Core.Typeof(val)) - end - return (false, nothing) -end - -function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} - if isa(arg, ConstantExpr) - ce = arg - while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr - ce = operands(ce)[1] - elseif opcode(ce) == LLVM.API.LLVMGetElementPtr - if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) - ce = operands(ce)[1] - else - break - end - else - break - end - end - if isa(ce, LLVM.GlobalVariable) - ce = LLVM.initializer(ce) - if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) - return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) - 
end - - end - end - return (false, "") + if legal + return (true, Core.Typeof(val), GPUCompiler.BITS_REF) + end + return (false, nothing, nothing) end +# +# function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} +# if isa(arg, ConstantExpr) +# ce = arg +# while isa(ce, ConstantExpr) +# if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr +# ce = operands(ce)[1] +# elseif opcode(ce) == LLVM.API.LLVMGetElementPtr +# if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) +# ce = operands(ce)[1] +# else +# break +# end +# else +# break +# end +# end +# if isa(ce, LLVM.GlobalVariable) +# ce = LLVM.initializer(ce) +# if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) +# return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) +# end +# +# end +# end +# return (false, "") +# end diff --git a/src/api.jl b/src/api.jl index 9e446dcf57..2861eba86f 100644 --- a/src/api.jl +++ b/src/api.jl @@ -20,51 +20,56 @@ struct IntList data::Ptr{Int64} size::Csize_t end -IntList() = IntList(Ptr{Int64}(0),0) - -@cenum(CConcreteType, - DT_Anything = 0, - DT_Integer = 1, - DT_Pointer = 2, - DT_Half = 3, - DT_Float = 4, - DT_Double = 5, - DT_Unknown = 6, - DT_FP80 = 7, - DT_BFloat16 = 8 +IntList() = IntList(Ptr{Int64}(0), 0) + +@cenum( + CConcreteType, + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, + DT_FP80 = 7, + DT_BFloat16 = 8 ) function EnzymeConcreteTypeIsFloat(cc::CConcreteType) - if cc == DT_Half - return LLVM.HalfType() - elseif cc == DT_Float - return LLVM.FloatType() - elseif cc == DT_Double - return LLVM.DoubleType() - elseif cc == DT_FP80 - return LLVM.X86FP80Type() - elseif cc == DT_BFloat16 - return LLVM.BFloatType() - else - return nothing - end -end - -@cenum(CValueType, - VT_None = 0, - VT_Primal = 1, - VT_Shadow = 2, - 
VT_Both = 3 -) - -function EnzymeBitcodeReplacement(mod, NotToReplace, found) + if cc == DT_Half + return LLVM.HalfType() + elseif cc == DT_Float + return LLVM.FloatType() + elseif cc == DT_Double + return LLVM.DoubleType() + elseif cc == DT_FP80 + return LLVM.X86FP80Type() + elseif cc == DT_BFloat16 + return LLVM.BFloatType() + else + return nothing + end +end + +@cenum(CValueType, VT_None = 0, VT_Primal = 1, VT_Shadow = 2, VT_Both = 3) + +function EnzymeBitcodeReplacement(mod, NotToReplace, found) foundSize = Ref{Csize_t}(0) foundP = Ref{Ptr{Cstring}}(C_NULL) - res = ccall((:EnzymeBitcodeReplacement, libEnzymeBCLoad), UInt8, (LLVM.API.LLVMModuleRef, Ptr{Cstring}, Csize_t, Ptr{Ptr{Cstring}}, Ptr{Csize_t}), mod, NotToReplace, length(NotToReplace), foundP, foundSize) + res = ccall( + (:EnzymeBitcodeReplacement, libEnzymeBCLoad), + UInt8, + (LLVM.API.LLVMModuleRef, Ptr{Cstring}, Csize_t, Ptr{Ptr{Cstring}}, Ptr{Csize_t}), + mod, + NotToReplace, + length(NotToReplace), + foundP, + foundSize, + ) foundNum = foundSize[] if foundNum != 0 foundP = foundP[] - for i in 1:foundNum + for i = 1:foundNum str = unsafe_load(foundP, i) push!(found, Base.unsafe_string(str)) Libc.free(str) @@ -72,34 +77,75 @@ function EnzymeBitcodeReplacement(mod, NotToReplace, found) end Libc.free(foundP) end - return res + return res end struct EnzymeTypeTree end const CTypeTreeRef = Ptr{EnzymeTypeTree} EnzymeNewTypeTree() = ccall((:EnzymeNewTypeTree, libEnzyme), CTypeTreeRef, ()) -EnzymeNewTypeTreeCT(T, ctx) = ccall((:EnzymeNewTypeTreeCT, libEnzyme), CTypeTreeRef, (CConcreteType, LLVMContextRef), T, ctx) -EnzymeNewTypeTreeTR(tt) = ccall((:EnzymeNewTypeTreeTR, libEnzyme), CTypeTreeRef, (CTypeTreeRef,), tt) +EnzymeNewTypeTreeCT(T, ctx) = ccall( + (:EnzymeNewTypeTreeCT, libEnzyme), + CTypeTreeRef, + (CConcreteType, LLVMContextRef), + T, + ctx, +) +EnzymeNewTypeTreeTR(tt) = + ccall((:EnzymeNewTypeTreeTR, libEnzyme), CTypeTreeRef, (CTypeTreeRef,), tt) EnzymeFreeTypeTree(tt) = 
ccall((:EnzymeFreeTypeTree, libEnzyme), Cvoid, (CTypeTreeRef,), tt) -EnzymeSetTypeTree(dst, src) = ccall((:EnzymeSetTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) -EnzymeMergeTypeTree(dst, src) = ccall((:EnzymeMergeTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) -function EnzymeCheckedMergeTypeTree(dst, src) +EnzymeSetTypeTree(dst, src) = + ccall((:EnzymeSetTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) +EnzymeMergeTypeTree(dst, src) = + ccall((:EnzymeMergeTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) +function EnzymeCheckedMergeTypeTree(dst, src) legal = Ref{UInt8}(0) - res = ccall((:EnzymeCheckedMergeTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef, Ptr{UInt8}), dst, src, legal) + res = ccall( + (:EnzymeCheckedMergeTypeTree, libEnzyme), + UInt8, + (CTypeTreeRef, CTypeTreeRef, Ptr{UInt8}), + dst, + src, + legal, + ) return res != 0, legal[] != 0 end -EnzymeTypeTreeOnlyEq(dst, x) = ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64), dst, x) -EnzymeTypeTreeLookupEq(dst, x, dl) = ccall((:EnzymeTypeTreeLookupEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl) -EnzymeTypeTreeCanonicalizeInPlace(dst, x, dl) = ccall((:EnzymeTypeTreeCanonicalizeInPlace, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl) -EnzymeTypeTreeData0Eq(dst) = ccall((:EnzymeTypeTreeData0Eq, libEnzyme), Cvoid, (CTypeTreeRef,), dst) -EnzymeTypeTreeInner0(dst) = ccall((:EnzymeTypeTreeInner0, libEnzyme), CConcreteType, (CTypeTreeRef,), dst) -EnzymeTypeTreeShiftIndiciesEq(dst, dl, offset, maxSize, addOffset) = - ccall((:EnzymeTypeTreeShiftIndiciesEq, libEnzyme), Cvoid, (CTypeTreeRef, Cstring, Int64, Int64, UInt64), - dst, dl, offset, maxSize, addOffset) +EnzymeTypeTreeOnlyEq(dst, x) = + ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64), dst, x) +EnzymeTypeTreeLookupEq(dst, x, dl) = ccall( + (:EnzymeTypeTreeLookupEq, libEnzyme), + Cvoid, + 
(CTypeTreeRef, Int64, Cstring), + dst, + x, + dl, +) +EnzymeTypeTreeCanonicalizeInPlace(dst, x, dl) = ccall( + (:EnzymeTypeTreeCanonicalizeInPlace, libEnzyme), + Cvoid, + (CTypeTreeRef, Int64, Cstring), + dst, + x, + dl, +) +EnzymeTypeTreeData0Eq(dst) = + ccall((:EnzymeTypeTreeData0Eq, libEnzyme), Cvoid, (CTypeTreeRef,), dst) +EnzymeTypeTreeInner0(dst) = + ccall((:EnzymeTypeTreeInner0, libEnzyme), CConcreteType, (CTypeTreeRef,), dst) +EnzymeTypeTreeShiftIndiciesEq(dst, dl, offset, maxSize, addOffset) = ccall( + (:EnzymeTypeTreeShiftIndiciesEq, libEnzyme), + Cvoid, + (CTypeTreeRef, Cstring, Int64, Int64, UInt64), + dst, + dl, + offset, + maxSize, + addOffset, +) -EnzymeTypeTreeToString(tt) = ccall((:EnzymeTypeTreeToString, libEnzyme), Cstring, (CTypeTreeRef,), tt) +EnzymeTypeTreeToString(tt) = + ccall((:EnzymeTypeTreeToString, libEnzyme), Cstring, (CTypeTreeRef,), tt) EnzymeStringFree(str) = ccall((:EnzymeStringFree, libEnzyme), Cvoid, (Cstring,), str) struct CFnTypeInfo @@ -109,35 +155,65 @@ struct CFnTypeInfo known_values::Ptr{IntList} end -SetMD(v::Union{LLVM.Instruction, LLVM.GlobalVariable}, kind::String, node::LLVM.Metadata) = ccall((:EnzymeSetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), v, kind, LLVM.Value(node)) +SetMD(v::Union{LLVM.Instruction,LLVM.GlobalVariable}, kind::String, node::LLVM.Metadata) = + ccall( + (:EnzymeSetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), + v, + kind, + LLVM.Value(node), + ) @static if !isdefined(LLVM, :ValueMetadataDict) -Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = - ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL - -function Base.getindex(md::LLVM.InstructionMetadataDict, kind::String) - objref = ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL - objref == C_NULL && throw(KeyError(kind)) - return 
LLVM.Metadata(LLVM.MetadataAsValue(objref)) - end - -Base.setindex!(md::LLVM.InstructionMetadataDict, node::LLVM.Metadata, kind::String) = - ccall((:EnzymeSetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), md.inst, kind, LLVM.Value(node)) -end + Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = + ccall( + (:EnzymeGetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring), + md.inst, + kind, + ) != C_NULL + + function Base.getindex(md::LLVM.InstructionMetadataDict, kind::String) + objref = + ccall( + (:EnzymeGetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring), + md.inst, + kind, + ) != C_NULL + objref == C_NULL && throw(KeyError(kind)) + return LLVM.Metadata(LLVM.MetadataAsValue(objref)) + end -@cenum(CDIFFE_TYPE, - DFT_OUT_DIFF = 0, # add differential to an output struct - DFT_DUP_ARG = 1, # duplicate the argument and store differential inside - DFT_CONSTANT = 2, # no differential - DFT_DUP_NONEED = 3 # duplicate this argument and store differential inside, - # but don't need the forward + Base.setindex!(md::LLVM.InstructionMetadataDict, node::LLVM.Metadata, kind::String) = + ccall( + (:EnzymeSetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), + md.inst, + kind, + LLVM.Value(node), + ) +end + +@cenum( + CDIFFE_TYPE, + DFT_OUT_DIFF = 0, # add differential to an output struct + DFT_DUP_ARG = 1, # duplicate the argument and store differential inside + DFT_CONSTANT = 2, # no differential + DFT_DUP_NONEED = 3 # duplicate this argument and store differential inside, + # but don't need the forward ) -@cenum(CDerivativeMode, - DEM_ForwardMode = 0, - DEM_ReverseModePrimal = 1, - DEM_ReverseModeGradient = 2, - DEM_ReverseModeCombined = 3 +@cenum( + CDerivativeMode, + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3 ) # Create the derivative function itself. 
@@ -155,31 +231,133 @@ end # pass # \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way # \p PostOpt is whether to perform basic optimization of the function after synthesis -function EnzymeCreatePrimalAndGradient(logic, todiff, retType, constant_args, TA, - returnValue, dretUsed, mode, runtimeActivity, width, additionalArg, - forceAnonymousTape, typeInfo, - uncacheable_args, augmented, atomicAdd) +function EnzymeCreatePrimalAndGradient( + logic, + todiff, + retType, + constant_args, + TA, + returnValue, + dretUsed, + mode, + runtimeActivity, + width, + additionalArg, + forceAnonymousTape, + typeInfo, + uncacheable_args, + augmented, + atomicAdd, +) freeMemory = true - ccall((:EnzymeCreatePrimalAndGradient, libEnzyme), LLVMValueRef, - (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, UInt8, CDerivativeMode, UInt8, Cuint, UInt8, LLVMTypeRef, UInt8, CFnTypeInfo, - Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr, UInt8), - logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - dretUsed, mode, runtimeActivity, width, freeMemory, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, length(uncacheable_args), - augmented, atomicAdd) -end - -function EnzymeCreateForwardDiff(logic, todiff, retType, constant_args, TA, - returnValue, mode, runtimeActivity, width, additionalArg, typeInfo, - uncacheable_args) + ccall( + (:EnzymeCreatePrimalAndGradient, libEnzyme), + LLVMValueRef, + ( + EnzymeLogicRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMValueRef, + CDIFFE_TYPE, + Ptr{CDIFFE_TYPE}, + Csize_t, + EnzymeTypeAnalysisRef, + UInt8, + UInt8, + CDerivativeMode, + UInt8, + Cuint, + UInt8, + LLVMTypeRef, + UInt8, + CFnTypeInfo, + Ptr{UInt8}, + Csize_t, + EnzymeAugmentedReturnPtr, + UInt8, + ), + logic, + C_NULL, + C_NULL, + todiff, + retType, + constant_args, + length(constant_args), + TA, + returnValue, + 
dretUsed, + mode, + runtimeActivity, + width, + freeMemory, + additionalArg, + forceAnonymousTape, + typeInfo, + uncacheable_args, + length(uncacheable_args), + augmented, + atomicAdd, + ) +end + +function EnzymeCreateForwardDiff( + logic, + todiff, + retType, + constant_args, + TA, + returnValue, + mode, + runtimeActivity, + width, + additionalArg, + typeInfo, + uncacheable_args, +) freeMemory = true aug = C_NULL - ccall((:EnzymeCreateForwardDiff, libEnzyme), LLVMValueRef, - (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, CDerivativeMode, UInt8, UInt8, Cuint, LLVMTypeRef, CFnTypeInfo, - Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr), - logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - mode, freeMemory, runtimeActivity, width, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args), aug) + ccall( + (:EnzymeCreateForwardDiff, libEnzyme), + LLVMValueRef, + ( + EnzymeLogicRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMValueRef, + CDIFFE_TYPE, + Ptr{CDIFFE_TYPE}, + Csize_t, + EnzymeTypeAnalysisRef, + UInt8, + CDerivativeMode, + UInt8, + UInt8, + Cuint, + LLVMTypeRef, + CFnTypeInfo, + Ptr{UInt8}, + Csize_t, + EnzymeAugmentedReturnPtr, + ), + logic, + C_NULL, + C_NULL, + todiff, + retType, + constant_args, + length(constant_args), + TA, + returnValue, + mode, + freeMemory, + runtimeActivity, + width, + additionalArg, + typeInfo, + uncacheable_args, + length(uncacheable_args), + aug, + ) end # Create an augmented forward pass. 
@@ -193,16 +371,61 @@ end # \p forceAnonymousTape forces the tape to be an i8* rather than the true tape structure # \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way # \p PostOpt is whether to perform basic optimization of the function after synthesis -function EnzymeCreateAugmentedPrimal(logic, todiff, retType, constant_args, TA, returnUsed, - shadowReturnUsed, - typeInfo, uncacheable_args, forceAnonymousTape, runtimeActivity, width, atomicAdd) - ccall((:EnzymeCreateAugmentedPrimal, libEnzyme), EnzymeAugmentedReturnPtr, - (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, UInt8, - CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, UInt8, Cuint, UInt8), - logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnUsed, +function EnzymeCreateAugmentedPrimal( + logic, + todiff, + retType, + constant_args, + TA, + returnUsed, + shadowReturnUsed, + typeInfo, + uncacheable_args, + forceAnonymousTape, + runtimeActivity, + width, + atomicAdd, +) + ccall( + (:EnzymeCreateAugmentedPrimal, libEnzyme), + EnzymeAugmentedReturnPtr, + ( + EnzymeLogicRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMValueRef, + CDIFFE_TYPE, + Ptr{CDIFFE_TYPE}, + Csize_t, + EnzymeTypeAnalysisRef, + UInt8, + UInt8, + CFnTypeInfo, + Ptr{UInt8}, + Csize_t, + UInt8, + UInt8, + Cuint, + UInt8, + ), + logic, + C_NULL, + C_NULL, + todiff, + retType, + constant_args, + length(constant_args), + TA, + returnUsed, shadowReturnUsed, - typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, runtimeActivity, width, atomicAdd) + typeInfo, + uncacheable_args, + length(uncacheable_args), + forceAnonymousTape, + runtimeActivity, + width, + atomicAdd, + ) end # typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, @@ -213,7 +436,15 @@ const CustomRuleType = Ptr{Cvoid} function CreateTypeAnalysis(logic, rulenames, rules) @assert 
length(rulenames) == length(rules) - ccall((:CreateTypeAnalysis, libEnzyme), EnzymeTypeAnalysisRef, (EnzymeLogicRef, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t), logic, rulenames, rules, length(rules)) + ccall( + (:CreateTypeAnalysis, libEnzyme), + EnzymeTypeAnalysisRef, + (EnzymeLogicRef, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t), + logic, + rulenames, + rules, + length(rules), + ) end function ClearTypeAnalysis(ta) @@ -225,67 +456,422 @@ function FreeTypeAnalysis(ta) end function EnzymeAnalyzeTypes(ta, CTI, F) - ccall((:EnzymeAnalyzeTypes, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeTypeAnalysisRef, CFnTypeInfo, LLVMValueRef), ta, CTI, F) + ccall( + (:EnzymeAnalyzeTypes, libEnzyme), + EnzymeTypeAnalyzerRef, + (EnzymeTypeAnalysisRef, CFnTypeInfo, LLVMValueRef), + ta, + CTI, + F, + ) end - + const CustomShadowAlloc = Ptr{Cvoid} const CustomShadowFree = Ptr{Cvoid} -EnzymeRegisterAllocationHandler(name, ahandle, fhandle) = ccall((:EnzymeRegisterAllocationHandler, libEnzyme), Cvoid, (Cstring, CustomShadowAlloc, CustomShadowFree), name, ahandle, fhandle) +EnzymeRegisterAllocationHandler(name, ahandle, fhandle) = ccall( + (:EnzymeRegisterAllocationHandler, libEnzyme), + Cvoid, + (Cstring, CustomShadowAlloc, CustomShadowFree), + name, + ahandle, + fhandle, +) const CustomAugmentedForwardPass = Ptr{Cvoid} const CustomForwardPass = Ptr{Cvoid} const CustomReversePass = Ptr{Cvoid} -EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle) -EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle) +EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall( + (:EnzymeRegisterCallHandler, libEnzyme), + Cvoid, + (Cstring, CustomAugmentedForwardPass, CustomReversePass), + name, + fwdhandle, + revhandle, +) +EnzymeRegisterFwdCallHandler(name, 
fwdhandle) = ccall( + (:EnzymeRegisterFwdCallHandler, libEnzyme), + Cvoid, + (Cstring, CustomForwardPass), + name, + fwdhandle, +) -EnzymeInsertValue(B::LLVM.IRBuilder, v::LLVM.Value, v2::LLVM.Value, insts::Vector{Cuint}, name="") = LLVM.Value(ccall((:EnzymeInsertValue, libEnzyme), LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVMValueRef, LLVMValueRef, Ptr{Cuint}, Int64, Cstring), B, v, v2, insts, length(insts), name)) +EnzymeInsertValue( + B::LLVM.IRBuilder, + v::LLVM.Value, + v2::LLVM.Value, + insts::Vector{Cuint}, + name = "", +) = LLVM.Value( + ccall( + (:EnzymeInsertValue, libEnzyme), + LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVMValueRef, LLVMValueRef, Ptr{Cuint}, Int64, Cstring), + B, + v, + v2, + insts, + length(insts), + name, + ), +) const CustomDiffUse = Ptr{Cvoid} -EnzymeRegisterDiffUseCallHandler(name, handle) = ccall((:EnzymeRegisterDiffUseCallHandler, libEnzyme), Cvoid, (Cstring, CustomDiffUse), name, handle) -EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall((:EnzymeSetCalledFunction, libEnzyme), Cvoid, (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), ci, fn, toremove, length(toremove)) -EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall((:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), LLVMValueRef, (LLVMValueRef,UInt8,Ptr{Int64}, Int64), fn, keepret, args, length(args)) -EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T) - -EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall((:EnzymeGradientUtilsReplaceAWithB, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef), gutils, a, b) -EnzymeGradientUtilsErase(gutils, a) = ccall((:EnzymeGradientUtilsErase, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, a) -EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall((:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef, 
UInt8), gutils, a, orig, erase) -EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils) -EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils) -EnzymeGradientUtilsGetRuntimeActivity(gutils) = ccall((:EnzymeGradientUtilsGetRuntimeActivity, libEnzyme), UInt8, (EnzymeGradientUtilsRef,), gutils) != 0 -EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) -EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig) -EnzymeGradientUtilsLookup(gutils, val, B) = ccall((:EnzymeGradientUtilsLookup, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) -EnzymeGradientUtilsInvertPointer(gutils, val, B) = ccall((:EnzymeGradientUtilsInvertPointer, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) -EnzymeGradientUtilsDiffe(gutils, val, B) = ccall((:EnzymeGradientUtilsDiffe, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) -EnzymeGradientUtilsAddToDiffe(gutils, val, diffe, B, T) = ccall((:EnzymeGradientUtilsAddToDiffe, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMTypeRef), gutils, val, diffe, B, T) -function EnzymeGradientUtilsAddToInvertedPointerDiffeTT(gutils, orig, origVal, vd, size, origptr, prediff, B, align, premask) - ccall((:EnzymeGradientUtilsAddToInvertedPointerDiffeTT, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, CTypeTreeRef, Cuint, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, Cuint, LLVMValueRef), 
gutils, orig, origVal, vd, size, origptr, prediff, B, align, premask) -end - -EnzymeGradientUtilsSetDiffe(gutils, val, diffe, B) = ccall((:EnzymeGradientUtilsSetDiffe, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, diffe, B) -EnzymeGradientUtilsIsConstantValue(gutils, val) = ccall((:EnzymeGradientUtilsIsConstantValue, libEnzyme), UInt8, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) -EnzymeGradientUtilsIsConstantInstruction(gutils, val) = ccall((:EnzymeGradientUtilsIsConstantInstruction, libEnzyme), UInt8, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) -EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocationBlock, libEnzyme), LLVM.API.LLVMBasicBlockRef, (EnzymeGradientUtilsRef,), gutils) - -EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils) - -EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val) - -EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), UInt8, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) - -EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) - -EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}, CDerivativeMode), gutils, orig, needsPrimalP, needsShadowP, mode) - -EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, 
origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme), - Cvoid, - ( EnzymeGradientUtilsRef, CDerivativeMode, LLVMTypeRef, UInt64, UInt64, UInt64, UInt64, UInt8, LLVMValueRef, UInt8, LLVMValueRef, LLVMValueRef, LLVMValueRef, LLVMValueRef, UInt8, UInt8), - gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) - -EnzymeGradientUtilsCallWithInvertedBundles(gutils, func, funcTy, argvs, argc, orig, valTys, valCnt, B, lookup) = ccall((:EnzymeGradientUtilsCallWithInvertedBundles, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMTypeRef, Ptr{LLVMValueRef}, UInt64, LLVMValueRef, Ptr{CValueType}, UInt64, LLVM.API.LLVMBuilderRef, UInt8), gutils, func, funcTy, argvs, argc, orig, valTys, valCnt, B, lookup) - -function sub_transfer(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) +EnzymeRegisterDiffUseCallHandler(name, handle) = ccall( + (:EnzymeRegisterDiffUseCallHandler, libEnzyme), + Cvoid, + (Cstring, CustomDiffUse), + name, + handle, +) +EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall( + (:EnzymeSetCalledFunction, libEnzyme), + Cvoid, + (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), + ci, + fn, + toremove, + length(toremove), +) +EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall( + (:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), + LLVMValueRef, + (LLVMValueRef, UInt8, Ptr{Int64}, Int64), + fn, + keepret, + args, + length(args), +) +EnzymeGetShadowType(width, T) = + ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64, LLVMTypeRef), width, T) + +EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall( + (:EnzymeGradientUtilsReplaceAWithB, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, 
LLVMValueRef, LLVMValueRef), + gutils, + a, + b, +) +EnzymeGradientUtilsErase(gutils, a) = ccall( + (:EnzymeGradientUtilsErase, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + a, +) +EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall( + (:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, UInt8), + gutils, + a, + orig, + erase, +) +EnzymeGradientUtilsGetMode(gutils) = ccall( + (:EnzymeGradientUtilsGetMode, libEnzyme), + CDerivativeMode, + (EnzymeGradientUtilsRef,), + gutils, +) +EnzymeGradientUtilsGetWidth(gutils) = ccall( + (:EnzymeGradientUtilsGetWidth, libEnzyme), + UInt64, + (EnzymeGradientUtilsRef,), + gutils, +) +EnzymeGradientUtilsGetRuntimeActivity(gutils) = + ccall( + (:EnzymeGradientUtilsGetRuntimeActivity, libEnzyme), + UInt8, + (EnzymeGradientUtilsRef,), + gutils, + ) != 0 +EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall( + (:EnzymeGradientUtilsNewFromOriginal, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) +EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall( + (:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), + gutils, + val, + orig, +) +EnzymeGradientUtilsLookup(gutils, val, B) = ccall( + (:EnzymeGradientUtilsLookup, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + B, +) +EnzymeGradientUtilsInvertPointer(gutils, val, B) = ccall( + (:EnzymeGradientUtilsInvertPointer, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + B, +) +EnzymeGradientUtilsDiffe(gutils, val, B) = ccall( + (:EnzymeGradientUtilsDiffe, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + B, +) +EnzymeGradientUtilsAddToDiffe(gutils, val, diffe, B, 
T) = ccall( + (:EnzymeGradientUtilsAddToDiffe, libEnzyme), + Cvoid, + ( + EnzymeGradientUtilsRef, + LLVMValueRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMTypeRef, + ), + gutils, + val, + diffe, + B, + T, +) +function EnzymeGradientUtilsAddToInvertedPointerDiffeTT( + gutils, + orig, + origVal, + vd, + size, + origptr, + prediff, + B, + align, + premask, +) + ccall( + (:EnzymeGradientUtilsAddToInvertedPointerDiffeTT, libEnzyme), + Cvoid, + ( + EnzymeGradientUtilsRef, + LLVMValueRef, + LLVMValueRef, + CTypeTreeRef, + Cuint, + LLVMValueRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + Cuint, + LLVMValueRef, + ), + gutils, + orig, + origVal, + vd, + size, + origptr, + prediff, + B, + align, + premask, + ) +end + +EnzymeGradientUtilsSetDiffe(gutils, val, diffe, B) = ccall( + (:EnzymeGradientUtilsSetDiffe, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + diffe, + B, +) +EnzymeGradientUtilsIsConstantValue(gutils, val) = ccall( + (:EnzymeGradientUtilsIsConstantValue, libEnzyme), + UInt8, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) +EnzymeGradientUtilsIsConstantInstruction(gutils, val) = ccall( + (:EnzymeGradientUtilsIsConstantInstruction, libEnzyme), + UInt8, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) +EnzymeGradientUtilsAllocationBlock(gutils) = ccall( + (:EnzymeGradientUtilsAllocationBlock, libEnzyme), + LLVM.API.LLVMBasicBlockRef, + (EnzymeGradientUtilsRef,), + gutils, +) + +EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall( + (:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), + EnzymeTypeAnalyzerRef, + (EnzymeGradientUtilsRef,), + gutils, +) + +EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall( + (:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), + CTypeTreeRef, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) + +EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall( + (:EnzymeGradientUtilsGetUncacheableArgs, 
libEnzyme), + UInt8, + (EnzymeGradientUtilsRef, LLVMValueRef, Ptr{UInt8}, UInt64), + gutils, + orig, + uncacheable, + size, +) + +EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall( + (:EnzymeGradientUtilsGetDiffeType, libEnzyme), + CDIFFE_TYPE, + (EnzymeGradientUtilsRef, LLVMValueRef, UInt8), + gutils, + op, + isforeign, +) + +EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) = + ccall( + (:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), + CDIFFE_TYPE, + (EnzymeGradientUtilsRef, LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}, CDerivativeMode), + gutils, + orig, + needsPrimalP, + needsShadowP, + mode, + ) + +EnzymeGradientUtilsSubTransferHelper( + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, +) = ccall( + (:EnzymeGradientUtilsSubTransferHelper, libEnzyme), + Cvoid, + ( + EnzymeGradientUtilsRef, + CDerivativeMode, + LLVMTypeRef, + UInt64, + UInt64, + UInt64, + UInt64, + UInt8, + LLVMValueRef, + UInt8, + LLVMValueRef, + LLVMValueRef, + LLVMValueRef, + LLVMValueRef, + UInt8, + UInt8, + ), + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, +) + +EnzymeGradientUtilsCallWithInvertedBundles( + gutils, + func, + funcTy, + argvs, + argc, + orig, + valTys, + valCnt, + B, + lookup, +) = ccall( + (:EnzymeGradientUtilsCallWithInvertedBundles, libEnzyme), + LLVMValueRef, + ( + EnzymeGradientUtilsRef, + LLVMValueRef, + LLVMTypeRef, + Ptr{LLVMValueRef}, + UInt64, + LLVMValueRef, + Ptr{CValueType}, + UInt64, + LLVM.API.LLVMBuilderRef, + UInt8, + ), + gutils, + func, + funcTy, + argvs, + argc, + orig, + valTys, + valCnt, + B, + lookup, +) + +function sub_transfer( + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + 
origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, +) GC.@preserve secretty begin if secretty === nothing secretty = Base.unsafe_convert(LLVMTypeRef, C_NULL) @@ -293,15 +879,37 @@ function sub_transfer(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, off secretty = Base.unsafe_convert(LLVMTypeRef, secretty) end - EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) + EnzymeGradientUtilsSubTransferHelper( + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, + ) end end -function CreateLogic(postOpt=false) +function CreateLogic(postOpt = false) ccall((:CreateEnzymeLogic, libEnzyme), EnzymeLogicRef, (UInt8,), postOpt) end -EnzymeLogicErasePreprocessedFunctions(logic) = ccall((:EnzymeLogicErasePreprocessedFunctions, libEnzyme), Cvoid, (EnzymeLogicRef,), logic) +EnzymeLogicErasePreprocessedFunctions(logic) = ccall( + (:EnzymeLogicErasePreprocessedFunctions, libEnzyme), + Cvoid, + (EnzymeLogicRef,), + logic, +) function ClearLogic(logic) ccall((:ClearEnzymeLogic, libEnzyme), Cvoid, (EnzymeLogicRef,), logic) @@ -313,22 +921,43 @@ end function EnzymeExtractReturnInfo(ret, data, existed) @assert length(data) == length(existed) - ccall((:EnzymeExtractReturnInfo, libEnzyme), - Cvoid, (EnzymeAugmentedReturnPtr, Ptr{Int64}, Ptr{UInt8}, Csize_t), - ret, data, existed, length(data)) + ccall( + (:EnzymeExtractReturnInfo, libEnzyme), + Cvoid, + (EnzymeAugmentedReturnPtr, Ptr{Int64}, Ptr{UInt8}, Csize_t), + ret, + data, + existed, + length(data), + ) end function EnzymeExtractFunctionFromAugmentation(ret) - ccall((:EnzymeExtractFunctionFromAugmentation, libEnzyme), LLVMValueRef, (EnzymeAugmentedReturnPtr,), ret) + ccall( + 
(:EnzymeExtractFunctionFromAugmentation, libEnzyme), + LLVMValueRef, + (EnzymeAugmentedReturnPtr,), + ret, + ) end function EnzymeExtractTapeTypeFromAugmentation(ret) - ccall((:EnzymeExtractTapeTypeFromAugmentation, libEnzyme), LLVMTypeRef, (EnzymeAugmentedReturnPtr,), ret) + ccall( + (:EnzymeExtractTapeTypeFromAugmentation, libEnzyme), + LLVMTypeRef, + (EnzymeAugmentedReturnPtr,), + ret, + ) end function EnzymeExtractUnderlyingTapeTypeFromAugmentation(ret) - ccall((:EnzymeExtractUnderlyingTapeTypeFromAugmentation, libEnzyme), LLVMTypeRef, (EnzymeAugmentedReturnPtr,), ret) + ccall( + (:EnzymeExtractUnderlyingTapeTypeFromAugmentation, libEnzyme), + LLVMTypeRef, + (EnzymeAugmentedReturnPtr,), + ret, + ) end import Libdl @@ -598,28 +1227,44 @@ function EnzymeRemoveTrivialAtomicIncrements(func) end function EnzymeAddAttributorLegacyPass(PM) - ccall((:EnzymeAddAttributorLegacyPass, libEnzyme),Cvoid,(LLVM.API.LLVMPassManagerRef,), PM) -end - -@cenum(ErrorType, - ET_NoDerivative = 0, - ET_NoShadow = 1, - ET_IllegalTypeAnalysis = 2, - ET_NoType = 3, - ET_IllegalFirstPointer = 4, - ET_InternalError = 5, - ET_TypeDepthExceeded = 6, - ET_MixedActivityError = 7, - ET_IllegalReplaceFicticiousPHIs = 8, - ET_GetIndexError = 9 + ccall( + (:EnzymeAddAttributorLegacyPass, libEnzyme), + Cvoid, + (LLVM.API.LLVMPassManagerRef,), + PM, + ) +end + +@cenum( + ErrorType, + ET_NoDerivative = 0, + ET_NoShadow = 1, + ET_IllegalTypeAnalysis = 2, + ET_NoType = 3, + ET_IllegalFirstPointer = 4, + ET_InternalError = 5, + ET_TypeDepthExceeded = 6, + ET_MixedActivityError = 7, + ET_IllegalReplaceFicticiousPHIs = 8, + ET_GetIndexError = 9 ) function EnzymeTypeAnalyzerToString(typeanalyzer) - ccall((:EnzymeTypeAnalyzerToString, libEnzyme), Cstring, (EnzymeTypeAnalyzerRef,), typeanalyzer) + ccall( + (:EnzymeTypeAnalyzerToString, libEnzyme), + Cstring, + (EnzymeTypeAnalyzerRef,), + typeanalyzer, + ) end function EnzymeGradientUtilsInvertedPointersToString(gutils) - 
ccall((:EnzymeGradientUtilsInvertedPointersToString, libEnzyme), Cstring, (Ptr{Cvoid},), gutils) + ccall( + (:EnzymeGradientUtilsInvertedPointersToString, libEnzyme), + Cstring, + (Ptr{Cvoid},), + gutils, + ) end function EnzymeSetHandler(handler) @@ -694,60 +1339,162 @@ function __init__() end function moveBefore(i1, i2, BR) - ccall((:EnzymeMoveBefore, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef), i1, i2, BR) + ccall( + (:EnzymeMoveBefore, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef), + i1, + i2, + BR, + ) end function EnzymeCloneFunctionDISubprogramInto(i1, i2) - ccall((:EnzymeCloneFunctionDISubprogramInto, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2) + ccall( + (:EnzymeCloneFunctionDISubprogramInto, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef), + i1, + i2, + ) end function EnzymeCopyMetadata(i1, i2) - ccall((:EnzymeCopyMetadata, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2) + ccall( + (:EnzymeCopyMetadata, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef), + i1, + i2, + ) end function SetMustCache!(i1) - ccall((:EnzymeSetMustCache, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,), i1) + ccall((:EnzymeSetMustCache, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), i1) end function SetForMemSet!(i1) - ccall((:EnzymeSetForMemSet, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,), i1) + ccall((:EnzymeSetForMemSet, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), i1) end function HasFromStack(i1) - ccall((:EnzymeHasFromStack, libEnzyme),UInt8,(LLVM.API.LLVMValueRef,), i1) != 0 + ccall((:EnzymeHasFromStack, libEnzyme), UInt8, (LLVM.API.LLVMValueRef,), i1) != 0 end function AddPreserveNVVMPass!(pm, i8) - ccall((:AddPreserveNVVMPass, libEnzyme),Cvoid,(LLVM.API.LLVMPassManagerRef,UInt8), pm, i8) + ccall( + (:AddPreserveNVVMPass, libEnzyme), + Cvoid, + (LLVM.API.LLVMPassManagerRef, UInt8), + pm, + i8, + ) 
end function EnzymeReplaceFunctionImplementation(mod) - ccall((:EnzymeReplaceFunctionImplementation, libEnzyme),Cvoid,(LLVM.API.LLVMModuleRef,), mod) + ccall( + (:EnzymeReplaceFunctionImplementation, libEnzyme), + Cvoid, + (LLVM.API.LLVMModuleRef,), + mod, + ) end function EnzymeDumpModuleRef(mod) - ccall((:EnzymeDumpModuleRef, libEnzyme),Cvoid,(LLVM.API.LLVMModuleRef,), mod) -end - -EnzymeComputeByteOffsetOfGEP(B, V, T) = LLVM.Value(ccall((:EnzymeComputeByteOffsetOfGEP, libEnzyme), LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMTypeRef), B, V, T)) - -EnzymeAllocaType(al) = LLVM.LLVMType(ccall((:EnzymeAllocaType, libEnzyme), LLVM.API.LLVMTypeRef, (LLVM.API.LLVMValueRef,), al)) - -EnzymeAttributeKnownFunctions(f) = ccall((:EnzymeAttributeKnownFunctions, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) + ccall((:EnzymeDumpModuleRef, libEnzyme), Cvoid, (LLVM.API.LLVMModuleRef,), mod) +end + +EnzymeComputeByteOffsetOfGEP(B, V, T) = LLVM.Value( + ccall( + (:EnzymeComputeByteOffsetOfGEP, libEnzyme), + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMTypeRef), + B, + V, + T, + ), +) -EnzymeAnonymousAliasScopeDomain(str, ctx) = LLVM.Metadata(ccall((:EnzymeAnonymousAliasScopeDomain, libEnzyme), LLVM.API.LLVMMetadataRef, (Cstring,LLVMContextRef), str, ctx)) -EnzymeAnonymousAliasScope(dom::LLVM.Metadata, str) = LLVM.Metadata(ccall((:EnzymeAnonymousAliasScope, libEnzyme), LLVM.API.LLVMMetadataRef, (LLVM.API.LLVMMetadataRef,Cstring), dom.ref, str)) -EnzymeFixupJuliaCallingConvention(f) = ccall((:EnzymeFixupJuliaCallingConvention, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) -EnzymeFixupBatchedJuliaCallingConvention(f) = ccall((:EnzymeFixupBatchedJuliaCallingConvention, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) +EnzymeAllocaType(al) = LLVM.LLVMType( + ccall( + (:EnzymeAllocaType, libEnzyme), + LLVM.API.LLVMTypeRef, + (LLVM.API.LLVMValueRef,), + al, + ), +) -e_extract_value!(builder, AggVal, 
Index, Name::String="") = - GC.@preserve Index begin - LLVM.Value(ccall((:EnzymeBuildExtractValue, libEnzyme), LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Ptr{Cuint}, Cuint, Cstring), builder, AggVal, Index, length(Index), Name)) - end +EnzymeAttributeKnownFunctions(f) = + ccall((:EnzymeAttributeKnownFunctions, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) + +EnzymeAnonymousAliasScopeDomain(str, ctx) = LLVM.Metadata( + ccall( + (:EnzymeAnonymousAliasScopeDomain, libEnzyme), + LLVM.API.LLVMMetadataRef, + (Cstring, LLVMContextRef), + str, + ctx, + ), +) +EnzymeAnonymousAliasScope(dom::LLVM.Metadata, str) = LLVM.Metadata( + ccall( + (:EnzymeAnonymousAliasScope, libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef, Cstring), + dom.ref, + str, + ), +) +EnzymeFixupJuliaCallingConvention(f) = ccall( + (:EnzymeFixupJuliaCallingConvention, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef,), + f, +) +EnzymeFixupBatchedJuliaCallingConvention(f) = ccall( + (:EnzymeFixupBatchedJuliaCallingConvention, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef,), + f, +) -e_insert_value!(builder, AggVal, EltVal, Index, Name::String="") = - GC.@preserve Index begin - LLVM.Value(ccall((:EnzymeBuildInsertValue, libEnzyme), LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, Ptr{Cuint}, Cuint, Cstring), builder, AggVal, EltVal, Index, length(Index), Name)) - end +e_extract_value!(builder, AggVal, Index, Name::String = "") = GC.@preserve Index begin + LLVM.Value( + ccall( + (:EnzymeBuildExtractValue, libEnzyme), + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Ptr{Cuint}, Cuint, Cstring), + builder, + AggVal, + Index, + length(Index), + Name, + ), + ) +end + +e_insert_value!(builder, AggVal, EltVal, Index, Name::String = "") = + GC.@preserve Index begin + LLVM.Value( + ccall( + (:EnzymeBuildInsertValue, libEnzyme), + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + 
LLVM.API.LLVMValueRef, + LLVM.API.LLVMValueRef, + Ptr{Cuint}, + Cuint, + Cstring, + ), + builder, + AggVal, + EltVal, + Index, + length(Index), + Name, + ), + ) + end end diff --git a/src/compiler.jl b/src/compiler.jl index ce51a6e7f5..f3680cc4c0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,12 +1,32 @@ module Compiler import ..Enzyme -import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, - BatchDuplicatedNoNeed, - BatchDuplicatedFunc, - Annotation, guess_activity, eltype, - API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, to_fullmd, - TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype +import Enzyme: + Const, + Active, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + BatchDuplicatedFunc, + Annotation, + guess_activity, + eltype, + API, + TypeTree, + typetree, + TypeTreeTable, + only!, + shift!, + data0!, + merge!, + to_md, + to_fullmd, + TypeAnalysis, + FnTypeInfo, + Logic, + allocatedinline, + ismutabletype using Enzyme import EnzymeCore @@ -45,12 +65,11 @@ include("gradientutils.jl") # Julia function to LLVM stem and arity const cmplx_known_ops = -Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( - typeof(Base.inv) => (:cmplx_inv, 1, nothing), - typeof(Base.sqrt) => (:cmplx_sqrt, 1, nothing), - ) -const known_ops = -Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( + Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,DataType}}}}( + typeof(Base.inv) => (:cmplx_inv, 1, nothing), + typeof(Base.sqrt) => (:cmplx_sqrt, 1, nothing), + ) +const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,DataType}}}}( typeof(Base.cbrt) => (:cbrt, 1, nothing), typeof(Base.rem2pi) => (:jl_rem2pi, 2, nothing), typeof(Base.sqrt) => (:sqrt, 1, nothing), @@ -85,7 +104,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( typeof(Base.tanh) => (:tanh, 1, nothing), 
typeof(Base.ldexp) => (:ldexp, 2, nothing), typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing), - typeof(Base.fma_emulated) => (:fma, 3, nothing) + typeof(Base.fma_emulated) => (:fma, 3, nothing), ) @inline function find_math_method(@nospecialize(func), sparam_vals) if func ∈ keys(known_ops) @@ -118,7 +137,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( return name, toinject, T end end - end + end if func ∈ keys(cmplx_known_ops) name, arity, toinject = cmplx_known_ops[func] @@ -146,67 +165,123 @@ const nofreefns = Set{String}(( "pcre2_jit_stack_assign_8", "pcre2_match_context_create_8", "pcre2_jit_stack_create_8", - "ijl_gc_enable_finalizers_internal", "jl_gc_enable_finalizers_internal", + "ijl_gc_enable_finalizers_internal", + "jl_gc_enable_finalizers_internal", "pcre2_match_data_create_from_pattern_8", - "ijl_gc_run_pending_finalizers", "jl_gc_run_pending_finalizers", - "ijl_typeassert", "jl_typeassert", - "ijl_f_isdefined", "jl_f_isdefined", - "ijl_field_index", "jl_field_index", - "ijl_specializations_get_linfo", "jl_specializations_get_linfo", - "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", - "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", - "ijl_array_grow_at", "jl_array_grow_at", - "ijl_try_substrtod", "jl_try_substrtod", + "ijl_gc_run_pending_finalizers", + "jl_gc_run_pending_finalizers", + "ijl_typeassert", + "jl_typeassert", + "ijl_f_isdefined", + "jl_f_isdefined", + "ijl_field_index", + "jl_field_index", + "ijl_specializations_get_linfo", + "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", + "jl_gf_invoke_lookup_worlds", + "ijl_gc_get_total_bytes", + "jl_gc_get_total_bytes", + "ijl_array_grow_at", + "jl_array_grow_at", + "ijl_try_substrtod", + "jl_try_substrtod", "jl_f__apply_iterate", - "ijl_field_index", "jl_field_index", - "julia.call", "julia.call2", - "ijl_tagged_gensym", "jl_tagged_gensym", - "ijl_array_ptr_copy", "jl_array_ptr_copy", - "ijl_array_copy", "jl_array_copy", - 
"ijl_get_nth_field_checked", "ijl_get_nth_field_checked", - "jl_array_del_end","ijl_array_del_end", - "jl_get_world_counter", "ijl_get_world_counter", - "memhash32_seed", "memhash_seed", - "ijl_module_parent", "jl_module_parent", + "ijl_field_index", + "jl_field_index", + "julia.call", + "julia.call2", + "ijl_tagged_gensym", + "jl_tagged_gensym", + "ijl_array_ptr_copy", + "jl_array_ptr_copy", + "ijl_array_copy", + "jl_array_copy", + "ijl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_array_del_end", + "ijl_array_del_end", + "jl_get_world_counter", + "ijl_get_world_counter", + "memhash32_seed", + "memhash_seed", + "ijl_module_parent", + "jl_module_parent", "julia.safepoint", - "ijl_set_task_tid", "jl_set_task_tid", - "ijl_get_task_tid", "jl_get_task_tid", + "ijl_set_task_tid", + "jl_set_task_tid", + "ijl_get_task_tid", + "jl_get_task_tid", "julia.get_pgcstack_or_new", - "ijl_global_event_loop", "jl_global_event_loop", - "ijl_gf_invoke_lookup", "jl_gf_invoke_lookup", - "ijl_f_typeassert", "jl_f_typeassert", - "ijl_type_unionall", "jl_type_unionall", - "jl_gc_queue_root", "gpu_report_exception", "gpu_signal_exception", - "julia.ptls_states", "julia.write_barrier", "julia.typeof", - "jl_backtrace_from_here", "ijl_backtrace_from_here", - "jl_box_int64", "jl_box_int32", - "ijl_box_int64", "ijl_box_int32", - "jl_box_uint64", "jl_box_uint32", - "ijl_box_uint64", "ijl_box_uint32", - "ijl_box_char", "jl_box_char", + "ijl_global_event_loop", + "jl_global_event_loop", + "ijl_gf_invoke_lookup", + "jl_gf_invoke_lookup", + "ijl_f_typeassert", + "jl_f_typeassert", + "ijl_type_unionall", + "jl_type_unionall", + "jl_gc_queue_root", + "gpu_report_exception", + "gpu_signal_exception", + "julia.ptls_states", + "julia.write_barrier", + "julia.typeof", + "jl_backtrace_from_here", + "ijl_backtrace_from_here", + "jl_box_int64", + "jl_box_int32", + "ijl_box_int64", + "ijl_box_int32", + "jl_box_uint64", + "jl_box_uint32", + "ijl_box_uint64", + "ijl_box_uint32", + 
"ijl_box_char", + "jl_box_char", "ijl_subtype", - "jl_subtype", "julia.get_pgcstack", "jl_in_threaded_region", - "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id", + "jl_subtype", + "julia.get_pgcstack", + "jl_in_threaded_region", + "jl_object_id_", + "jl_object_id", + "ijl_object_id_", + "ijl_object_id", "jl_breakpoint", - "llvm.julia.gc_preserve_begin","llvm.julia.gc_preserve_end", + "llvm.julia.gc_preserve_begin", + "llvm.julia.gc_preserve_end", "jl_get_ptls_states", "ijl_get_ptls_states", "jl_f_fieldtype", "jl_symbol_n", - "jl_stored_inline", "ijl_stored_inline", - "jl_f_apply_type", "jl_f_issubtype", - "jl_isa", "ijl_isa", - "jl_matching_methods", "ijl_matching_methods", - "jl_excstack_state", "ijl_excstack_state", - "jl_current_exception", "ijl_current_exception", + "jl_stored_inline", + "ijl_stored_inline", + "jl_f_apply_type", + "jl_f_issubtype", + "jl_isa", + "ijl_isa", + "jl_matching_methods", + "ijl_matching_methods", + "jl_excstack_state", + "ijl_excstack_state", + "jl_current_exception", + "ijl_current_exception", "memhash_seed", - "jl_f__typevar", "ijl_f__typevar", - "jl_f_isa", "ijl_f_isa", - "jl_set_task_threadpoolid", "ijl_set_task_threadpoolid", - "jl_types_equal", "ijl_types_equal", - "jl_invoke", "ijl_invoke", - "jl_apply_generic", "ijl_apply_generic", - "jl_egal__unboxed", "julia.pointer_from_objref", "_platform_memcmp", + "jl_f__typevar", + "ijl_f__typevar", + "jl_f_isa", + "ijl_f_isa", + "jl_set_task_threadpoolid", + "ijl_set_task_threadpoolid", + "jl_types_equal", + "ijl_types_equal", + "jl_invoke", + "ijl_invoke", + "jl_apply_generic", + "ijl_apply_generic", + "jl_egal__unboxed", + "julia.pointer_from_objref", + "_platform_memcmp", "memcmp", "julia.except_enter", "jl_array_grow_end", @@ -233,53 +308,96 @@ const nofreefns = Set{String}(( const inactivefns = Set{String}(( "pcre2_match_data_create_from_pattern_8", - "ijl_typeassert", "jl_typeassert", - "ijl_f_isdefined", "jl_f_isdefined", - "ijl_field_index", "jl_field_index", 
- "ijl_specializations_get_linfo", "jl_specializations_get_linfo", - "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", - "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", - "ijl_try_substrtod", "jl_try_substrtod", - "ijl_tagged_gensym", "jl_tagged_gensym", - "jl_get_world_counter", "ijl_get_world_counter", - "memhash32_seed", "memhash_seed", - "ijl_module_parent", "jl_module_parent", + "ijl_typeassert", + "jl_typeassert", + "ijl_f_isdefined", + "jl_f_isdefined", + "ijl_field_index", + "jl_field_index", + "ijl_specializations_get_linfo", + "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", + "jl_gf_invoke_lookup_worlds", + "ijl_gc_get_total_bytes", + "jl_gc_get_total_bytes", + "ijl_try_substrtod", + "jl_try_substrtod", + "ijl_tagged_gensym", + "jl_tagged_gensym", + "jl_get_world_counter", + "ijl_get_world_counter", + "memhash32_seed", + "memhash_seed", + "ijl_module_parent", + "jl_module_parent", "julia.safepoint", - "ijl_set_task_tid", "jl_set_task_tid", - "ijl_get_task_tid", "jl_get_task_tid", + "ijl_set_task_tid", + "jl_set_task_tid", + "ijl_get_task_tid", + "jl_get_task_tid", "julia.get_pgcstack_or_new", - "ijl_global_event_loop", "jl_global_event_loop", - "ijl_gf_invoke_lookup", "jl_gf_invoke_lookup", - "ijl_f_typeassert", "jl_f_typeassert", - "ijl_type_unionall", "jl_type_unionall", - "jl_gc_queue_root", "gpu_report_exception", "gpu_signal_exception", - "julia.ptls_states", "julia.write_barrier", "julia.typeof", - "jl_backtrace_from_here", "ijl_backtrace_from_here", - "jl_box_int64", "jl_box_int32", - "ijl_box_int64", "ijl_box_int32", - "jl_box_uint64", "jl_box_uint32", - "ijl_box_uint64", "ijl_box_uint32", - "ijl_box_char", "jl_box_char", + "ijl_global_event_loop", + "jl_global_event_loop", + "ijl_gf_invoke_lookup", + "jl_gf_invoke_lookup", + "ijl_f_typeassert", + "jl_f_typeassert", + "ijl_type_unionall", + "jl_type_unionall", + "jl_gc_queue_root", + "gpu_report_exception", + "gpu_signal_exception", + "julia.ptls_states", + 
"julia.write_barrier", + "julia.typeof", + "jl_backtrace_from_here", + "ijl_backtrace_from_here", + "jl_box_int64", + "jl_box_int32", + "ijl_box_int64", + "ijl_box_int32", + "jl_box_uint64", + "jl_box_uint32", + "ijl_box_uint64", + "ijl_box_uint32", + "ijl_box_char", + "jl_box_char", "ijl_subtype", - "jl_subtype", "julia.get_pgcstack", "jl_in_threaded_region", - "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id", + "jl_subtype", + "julia.get_pgcstack", + "jl_in_threaded_region", + "jl_object_id_", + "jl_object_id", + "ijl_object_id_", + "ijl_object_id", "jl_breakpoint", - "llvm.julia.gc_preserve_begin","llvm.julia.gc_preserve_end", + "llvm.julia.gc_preserve_begin", + "llvm.julia.gc_preserve_end", "jl_get_ptls_states", "ijl_get_ptls_states", "jl_f_fieldtype", "jl_symbol_n", - "jl_stored_inline", "ijl_stored_inline", - "jl_f_apply_type", "jl_f_issubtype", - "jl_isa", "ijl_isa", - "jl_matching_methods", "ijl_matching_methods", - "jl_excstack_state", "ijl_excstack_state", - "jl_current_exception", "ijl_current_exception", + "jl_stored_inline", + "ijl_stored_inline", + "jl_f_apply_type", + "jl_f_issubtype", + "jl_isa", + "ijl_isa", + "jl_matching_methods", + "ijl_matching_methods", + "jl_excstack_state", + "ijl_excstack_state", + "jl_current_exception", + "ijl_current_exception", "memhash_seed", - "jl_f__typevar", "ijl_f__typevar", - "jl_f_isa", "ijl_f_isa", - "jl_set_task_threadpoolid", "ijl_set_task_threadpoolid", - "jl_types_equal", "ijl_types_equal", + "jl_f__typevar", + "ijl_f__typevar", + "jl_f_isa", + "ijl_f_isa", + "jl_set_task_threadpoolid", + "ijl_set_task_threadpoolid", + "jl_types_equal", + "ijl_types_equal", "jl_string_to_array", "ijl_string_to_array", "jl_alloc_string", @@ -292,13 +410,11 @@ const inactivefns = Set{String}(( "uv_os_homedir", "jl_array_to_string", "ijl_array_to_string", - "pcre2_jit_compile_8" + "pcre2_jit_compile_8", # "jl_" )) -const activefns = Set{String}(( - "jl_", -)) +const activefns = Set{String}(("jl_",)) const 
inactiveglobs = Set{String}(( "ijl_boxed_uint8_cache", @@ -322,7 +438,7 @@ struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed} world::worldT end -@inline element(::Val{T}) where T = T +@inline element(::Val{T}) where {T} = T # From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 @inline function isghostty(ty) @@ -338,7 +454,9 @@ end return false end -@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})(f::Int) where {seen,worldT,justActive,UnionSret,AbstractIsMixed} +@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})( + f::Int, +) where {seen,worldT,justActive,UnionSret,AbstractIsMixed} T = element(first(seen)) reftype = ismutabletype(T) || (T isa UnionAll && !AbstractIsMixed) @@ -353,7 +471,14 @@ end return Val(AnyState) end - sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + sub = active_reg_inner( + subT, + seen, + c.world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) if sub == AnyState Val(AnyState) @@ -374,9 +499,9 @@ end end end -@inline forcefold(::Val{RT}) where RT = RT +@inline forcefold(::Val{RT}) where {RT} = RT -@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any, N}) where {ty, sty, N} +@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N}) where {ty,sty,N} if sty == AnyState || sty == ty return forcefold(Val(ty), C...) 
end @@ -387,50 +512,92 @@ end end end -@inline ptreltype(::Type{Ptr{T}}) where T = T +@inline ptreltype(::Type{Ptr{T}}) where {T} = T @inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T @inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T -@inline ptreltype(::Type{Base.RefValue{T}}) where T = T +@inline ptreltype(::Type{Base.RefValue{T}}) where {T} = T @inline ptreltype(::Type{Array{T,N}}) where {T,N} = T -@inline ptreltype(::Type{Array{T, N} where N}) where {T} = T -@inline ptreltype(::Type{Complex{T}}) where T = T -@inline ptreltype(::Type{Tuple{Vararg{T}}}) where T = T -@inline ptreltype(::Type{IdDict{K, V}}) where {K, V} = V -@inline ptreltype(::Type{IdDict{K, V} where K}) where {V} = V +@inline ptreltype(::Type{Array{T,N} where N}) where {T} = T +@inline ptreltype(::Type{Complex{T}}) where {T} = T +@inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T +@inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V +@inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V @inline is_arrayorvararg_ty(::Type) = false @inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true -@inline is_arrayorvararg_ty(::Type{Array{T, N} where N}) where {T} = true -@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where T2 = true -@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where T = true +@inline is_arrayorvararg_ty(::Type{Array{T,N} where N}) where {T} = true +@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where {T2} = true +@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where {T} = true @inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N}}) where {T,N} = true @inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N} where N}) where {T} = true -@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where T = true -@inline is_arrayorvararg_ty(::Type{IdDict{K, V}}) where {K, V} = true -@inline is_arrayorvararg_ty(::Type{IdDict{K, V} where K}) where {V} = true +@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true +@inline 
is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true +@inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true -@inline function datatype_fieldcount(t::Type{T}) where T +@inline function datatype_fieldcount(t::Type{T}) where {T} return Base.datatype_fieldcount(t) end -@inline function staticInTup(::Val{T}, tup::NTuple{N, Val}) where {T, N} +@inline function staticInTup(::Val{T}, tup::NTuple{N,Val}) where {T,N} any(ntuple(Val(N)) do i Base.@_inline_meta Val(T) == tup[i] end) end -@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}, ::Val{AbstractIsMixed}) where {ST, Seen, justActive, UnionSret, AbstractIsMixed} +@inline function active_reg_recur( + ::Type{ST}, + seen::Seen, + world, + ::Val{justActive}, + ::Val{UnionSret}, + ::Val{AbstractIsMixed}, +) where {ST,Seen,justActive,UnionSret,AbstractIsMixed} if ST isa Union - return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)))) + return forcefold( + Val( + active_reg_recur( + ST.a, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ), + ), + Val( + active_reg_recur( + ST.b, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ), + ), + ) end - return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + return active_reg_inner( + ST, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) end @inline is_vararg_tup(x) = false -@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where T2 = true - -@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false), ::Val{AbstractIsMixed}=Val(false))::ActivityState where {ST,T, justActive, UnionSret, AbstractIsMixed} +@inline 
is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where {T2} = true + +@inline function active_reg_inner( + ::Type{T}, + seen::ST, + world::Union{Nothing,UInt}, + ::Val{justActive} = Val(false), + ::Val{UnionSret} = Val(false), + ::Val{AbstractIsMixed} = Val(false), +)::ActivityState where {ST,T,justActive,UnionSret,AbstractIsMixed} if T === Any if AbstractIsMixed return MixedState @@ -444,7 +611,14 @@ end end if T <: Complex && !(T isa UnionAll) - return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + return active_reg_inner( + ptreltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) end if T <: BigFloat @@ -455,12 +629,24 @@ end return ActiveState end - if T <: Ptr || T <: Core.LLVMPtr || T <: Base.RefValue || T <: Array || is_arrayorvararg_ty(T) + if T <: Ptr || + T <: Core.LLVMPtr || + T <: Base.RefValue || + T <: Array || + is_arrayorvararg_ty(T) if justActive return AnyState end - if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) == AnyState + if is_arrayorvararg_ty(T) && + active_reg_inner( + ptreltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) == AnyState return AnyState else if AbstractIsMixed && is_vararg_tup(T) @@ -482,10 +668,22 @@ end inactivety = if typeof(world) === Nothing EnzymeCore.EnzymeRules.inactive_type(T) else - inmi = GPUCompiler.methodinstance(typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world) - args = Any[EnzymeCore.EnzymeRules.inactive_type, T]; + inmi = GPUCompiler.methodinstance( + typeof(EnzymeCore.EnzymeRules.inactive_type), + Tuple{Type{T}}, + world, + ) + args = Any[EnzymeCore.EnzymeRules.inactive_type, T] GC.@preserve T begin - ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + ccall( + :jl_invoke, + Any, + (Any, Ptr{Any}, Cuint, Any), + 
EnzymeCore.EnzymeRules.inactive_type, + args, + length(args), + inmi, + ) end end @@ -516,19 +714,28 @@ end # if sret union, the data is stored in a stack memory location and is therefore # not unique'd preventing the boxing of the union in the default case if UnionSret && is_sret_union(T) - return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + return active_reg_recur( + T, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) else if justActive return AnyState end - if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != AnyState + if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != + AnyState if AbstractIsMixed return MixedState else return DupState end end - if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != AnyState + if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != + AnyState if AbstractIsMixed return MixedState else @@ -562,17 +769,19 @@ end end nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) - Tuple{(ntuple(length(T.parameters)) do i - Base.@_inline_meta - sT = T.parameters[i] - if sT isa TypeVar - Any - elseif sT isa Core.TypeofVararg - Any - else - sT + Tuple{( + ntuple(length(T.parameters)) do i + Base.@_inline_meta + sT = T.parameters[i] + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg + Any + else + sT + end end - end)...} + )...} else T end @@ -583,40 +792,48 @@ end seen2 = (Val(nT), seen...) - fty = Merger{seen2,typeof(world),justActive, UnionSret, AbstractIsMixed}(world) + fty = Merger{seen2,typeof(world),justActive,UnionSret,AbstractIsMixed}(world) ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...) 
return ty end -@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T, world} +@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T,world} return active_reg_inner(T, (), world) end -Base.@pure @inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} +Base.@pure @inline function active_reg( + ::Type{T}, + world::Union{Nothing,UInt} = nothing, +)::Bool where {T} seen = () # check if it could contain an active - if active_reg_inner(T, seen, world, #=justActive=#Val(true)) == ActiveState - state = active_reg_inner(T, seen, world, #=justActive=#Val(false)) + if active_reg_inner(T, seen, world, Val(true)) == ActiveState #=justActive=# + state = active_reg_inner(T, seen, world, Val(false)) #=justActive=# if state == ActiveState return true end @assert state == MixedState - throw(AssertionError(string(T)*" has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) + throw( + AssertionError( + string(T) * + " has mixed internal activity types. 
See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information", + ), + ) else return false end end -@inline function guaranteed_const(::Type{T}) where T +@inline function guaranteed_const(::Type{T}) where {T} rt = active_reg_nothrow(T, Val(nothing)) res = rt == AnyState return res end -@inline function guaranteed_const_nongen(::Type{T}, world) where T +@inline function guaranteed_const_nongen(::Type{T}, world) where {T} rt = active_reg_inner(T, (), world) res = rt == AnyState return res @@ -624,12 +841,13 @@ end # check if a value is guaranteed to be not contain active[register] data # (aka not either mixed or active) -@inline function guaranteed_nonactive(::Type{T}) where T +@inline function guaranteed_nonactive(::Type{T}) where {T} rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) +@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = + guess_activity(T, convert(API.CDerivativeMode, mode)) @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} ActReg = active_reg_inner(T, (), nothing) @@ -650,54 +868,72 @@ end end # User facing interface -abstract type AbstractThunk{FA, RT, TT, Width} end +abstract type AbstractThunk{FA,RT,TT,Width} end -struct CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal} <: AbstractThunk{FA, RT, TT, Width} +struct CombinedAdjointThunk{PT,FA,RT,TT,Width,ReturnPrimal} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -struct ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal} <: AbstractThunk{FA, RT, TT, Width} +struct ForwardModeThunk{PT,FA,RT,TT,Width,ReturnPrimal} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -struct AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType} <: AbstractThunk{FA, RT, TT, Width} +struct 
AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType} <: + AbstractThunk{FA,RT,TT,Width} primal::PT end -struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT, Width} +struct AdjointThunk{PT,FA,RT,TT,Width,TapeType} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal} <: AbstractThunk{FA, RT, TT, Width} +struct PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -@inline return_type(::AbstractThunk{FA, RT}) where {FA, RT} = RT -@inline return_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = RT - -@inline EnzymeRules.tape_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = TapeType -@inline EnzymeRules.tape_type(::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = TapeType -@inline EnzymeRules.tape_type(::Type{AdjointThunk{PT, FA, RT, TT, Width, TapeType}}) where {PT, FA, RT, TT, Width, TapeType} = TapeType -@inline EnzymeRules.tape_type(::AdjointThunk{PT, FA, RT, TT, Width, TapeType}) where {PT, FA, RT, TT, Width, TapeType} = TapeType +@inline return_type(::AbstractThunk{FA,RT}) where {FA,RT} = RT +@inline return_type( + ::Type{AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType}}, +) where {PT,FA,RT,TT,Width,ReturnPrimal,TapeType} = RT + +@inline EnzymeRules.tape_type( + ::Type{AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType}}, +) where {PT,FA,RT,TT,Width,ReturnPrimal,TapeType} = TapeType +@inline EnzymeRules.tape_type( + ::AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType}, +) where {PT,FA,RT,TT,Width,ReturnPrimal,TapeType} = TapeType +@inline EnzymeRules.tape_type( + ::Type{AdjointThunk{PT,FA,RT,TT,Width,TapeType}}, +) where 
{PT,FA,RT,TT,Width,TapeType} = TapeType +@inline EnzymeRules.tape_type( + ::AdjointThunk{PT,FA,RT,TT,Width,TapeType}, +) where {PT,FA,RT,TT,Width,TapeType} = TapeType using .JIT -declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do - T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) - T_size_t = convert(LLVM.LLVMType, Int) +declare_allocobj!(mod) = + get_function!(mod, "julia.gc_alloc_obj") do + T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) + T_size_t = convert(LLVM.LLVMType, Int) - LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) -end -function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool, name::String="") + LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) + end +function emit_allocobj!( + B, + tag::LLVM.Value, + Size::LLVM.Value, + needs_workaround::Bool, + name::String = "", +) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - T_jlvalue = LLVM.StructType(LLVMType[]) + T_jlvalue = LLVM.StructType(LLVMType[]) T_pjlvalue = LLVM.PointerType(T_jlvalue) T_ppjlvalue = LLVM.PointerType(T_pjlvalue) @@ -705,13 +941,13 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: T_pint8 = LLVM.PointerType(T_int8) pgcstack = reinsert_gcmarker!(fn, B) - ct = inbounds_gep!(B, + ct = inbounds_gep!( + B, T_pjlvalue, bitcast!(B, pgcstack, T_ppjlvalue), - [LLVM.ConstantInt(current_task_offset())]) - ptls_field = inbounds_gep!(B, - T_pjlvalue, - ct, [LLVM.ConstantInt(current_ptls_offset())]) + [LLVM.ConstantInt(current_task_offset())], + ) + ptls_field = inbounds_gep!(B, T_pjlvalue, ct, [LLVM.ConstantInt(current_ptls_offset())]) T_ppint8 = LLVM.PointerType(T_pint8) ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) @@ 
-732,12 +968,12 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: return call!(B, alty, alloc_obj, [ct, Size, tag], name) end -function emit_allocobj!(B, T::DataType, name::String="") +function emit_allocobj!(B, T::DataType, name::String = "") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - T_jlvalue = LLVM.StructType(LLVMType[]) + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -746,14 +982,15 @@ function emit_allocobj!(B, T::DataType, name::String="") T_size_t = convert(LLVM.LLVMType, UInt) Size = LLVM.ConstantInt(T_size_t, sizeof(T)) - emit_allocobj!(B, tag, Size, #=needs_workaround=#false, name) -end -declare_pointerfromobjref!(mod) = get_function!(mod, "julia.pointer_from_objref") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Derived) - T_pjlvalue = LLVM.PointerType(T_jlvalue) - LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) + emit_allocobj!(B, tag, Size, false, name) #=needs_workaround=# end +declare_pointerfromobjref!(mod) = + get_function!(mod, "julia.pointer_from_objref") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Derived) + T_pjlvalue = LLVM.PointerType(T_jlvalue) + LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) + end function emit_pointerfromobjref!(B, T) curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -762,21 +999,27 @@ function emit_pointerfromobjref!(B, T) return call!(B, fty, func, [T]) end -declare_writebarrier!(mod) = get_function!(mod, "julia.write_barrier") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg=true) -end -declare_apply_generic!(mod) = get_function!(mod, "ijl_apply_generic") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - 
LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()]) -end -declare_juliacall!(mod) = get_function!(mod, "julia.call") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg=true) -end +declare_writebarrier!(mod) = + get_function!(mod, "julia.write_barrier") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg = true) + end +declare_apply_generic!(mod) = + get_function!(mod, "ijl_apply_generic") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType( + T_prjlvalue, + [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()], + ) + end +declare_juliacall!(mod) = + get_function!(mod, "julia.call") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) + end function emit_jl!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value curent_bb = position(B) @@ -804,9 +1047,15 @@ function emit_getfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LL args = [val, fld] - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; + vararg = true, + ), + ) res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -818,7 +1067,7 @@ function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LL mod = LLVM.parent(fn) T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_size_t = convert(LLVM.LLVMType, Int) gen_FT = 
LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_size_t]) @@ -882,9 +1131,15 @@ function emit_apply_generic!(B::LLVM.IRBuilder, args)::LLVM.Value inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; + vararg = true, + ), + ) res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -900,13 +1155,20 @@ function emit_invoke!(B::LLVM.IRBuilder, args)::LLVM.Value T_int32 = LLVM.Int32Type() # {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* @ijl_invoke - gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) + gen_FT = + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) inv = get_function!(mod, "ijl_invoke", gen_FT) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call2", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + julia_call, FT = get_function!( + mod, + "julia.call2", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) res = call!(B, FT, julia_call, [inv, args...]) return res end @@ -920,13 +1182,13 @@ function emit_svec!(B, args)::LLVM.Value sz = convert(LLVMType, Csize_t) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [sz]; vararg=true) - + LLVM.FunctionType(T_prjlvalue, [sz]; vararg = true) + sz = convert(LLVMType, Csize_t) call!(B, fty, fn, [LLVM.ConstantInt(sz, length(args)), args...]) end -AnyArray(Length::Int) = NamedTuple{ntuple(i->Symbol(i), Val(Length)),NTuple{Length,Any}} +AnyArray(Length::Int) = NamedTuple{ntuple(i -> Symbol(i), Val(Length)),NTuple{Length,Any}} struct EnzymeRuntimeException <: Base.Exception msg::Cstring @@ -953,12 +1215,27 @@ end function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) println(io, "Constant memory is stored (or returned) to a differentiable variable.") - println(io, "As a result, Enzyme cannot provably ensure correctness and throws this error.") - println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).") - println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!"); - println(io, "To work around this issue, either:"); - println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or") - println(io, " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). 
This will maintain correctness, but may slightly reduce performance.") + println( + io, + "As a result, Enzyme cannot provably ensure correctness and throws this error.", + ) + println( + io, + "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).", + ) + println( + io, + "If Enzyme should be able to prove this use non-differentable, open an issue!", + ) + println(io, "To work around this issue, either:") + println( + io, + " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or", + ) + println( + io, + " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.", + ) msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end @@ -992,7 +1269,7 @@ function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) print(io, msg, '\n') end -const JuliaEnzymeNameMap = Dict{String, Any}( +const JuliaEnzymeNameMap = Dict{String,Any}( "enz_val_true" => Val(true), "enz_val_false" => Val(false), "enz_val_1" => Val(1), @@ -1007,7 +1284,7 @@ const JuliaEnzymeNameMap = Dict{String, Any}( "enz_no_derivative_exc" => EnzymeNoDerivativeError, ) -const JuliaGlobalNameMap = Dict{String, Any}( +const JuliaGlobalNameMap = Dict{String,Any}( "jl_type_type" => Type, "jl_any_type" => Any, "jl_datatype_type" => DataType, @@ -1015,12 +1292,10 @@ const JuliaGlobalNameMap = Dict{String, Any}( "jl_symbol_type" => Symbol, "jl_simplevector_type" => Core.SimpleVector, "jl_nothing_type" => Nothing, - "jl_tvar_type" => TypeVar, "jl_typeofbottom_type" => Core.TypeofBottom, "jl_bottom_type" => Union{}, "jl_unionall_type" => UnionAll, - "jl_uniontype_type" => Union, "jl_emptytuple_type" => Tuple{}, "jl_emptytuple" => (), @@ -1052,46 +1327,31 @@ const JuliaGlobalNameMap = Dict{String, Any}( "jl_ref_type" => Ref, "jl_pointer_typename" => Ptr, 
"jl_voidpointer_type" => Ptr{Nothing}, - "jl_abstractarray_type" => AbstractArray, - "jl_densearray_type" => DenseArray, - "jl_array_type" => Array, - - "jl_array_any_type" => Array{Any, 1}, - - "jl_array_symbol_type" => Array{Symbol, 1}, - - "jl_array_uint8_type" => Array{UInt8, 1}, + "jl_array_any_type" => Array{Any,1}, + "jl_array_symbol_type" => Array{Symbol,1}, + "jl_array_uint8_type" => Array{UInt8,1}, # "jl_array_uint32_type" => Array{UInt32, 1}, - "jl_array_int32_type" => Array{Int32, 1}, - - + "jl_array_int32_type" => Array{Int32,1}, "jl_expr_type" => Expr, - "jl_method_type" => Method, "jl_method_instance_type" => Core.MethodInstance, "jl_code_instance_type" => Core.CodeInstance, "jl_const_type" => Core.Const, "jl_llvmpointer_type" => Core.LLVMPtr, - - "jl_namedtuple_type" => NamedTuple, - "jl_task_type" => Task, - "jl_uint8pointer_type" => Ptr{UInt8}, - "jl_nothing" => nothing, - "jl_anytuple_type" => Tuple, "jl_vararg_type" => Core.TypeofVararg, "jl_opaque_closure_type" => Core.OpaqueClosure, - "jl_array_uint64_type" => Array{UInt64, 1}, - "jl_binding_type" => Core.Binding + "jl_array_uint64_type" => Array{UInt64,1}, + "jl_binding_type" => Core.Binding, ) include("absint.jl") @@ -1104,7 +1364,7 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value legal = true found = [] for arg in args - slegal , foundv = absint(arg) + slegal, foundv = absint(arg) if slegal push!(found, foundv) else @@ -1127,10 +1387,21 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value Ty = unsafe_to_llvm(B, Ty) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...]) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) + tag = call!( + B, + FT, + julia_call, + LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...], + ) return tag end @@ -1142,7 +1413,7 @@ function emit_tuple!(B, args)::LLVM.Value legal = true found = [] for arg in args - slegal , foundv = absint(arg) + slegal, foundv = absint(arg) if slegal push!(found, foundv) else @@ -1164,10 +1435,21 @@ function emit_tuple!(B, args)::LLVM.Value f_apply_type, _ = get_function!(mod, "jl_f_tuple", generic_FT) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...]) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) + tag = call!( + B, + FT, + julia_call, + LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...], + ) return tag end @@ -1176,14 +1458,14 @@ function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - legal, val = abs_typeof(arg) + legal, val, byref = abs_typeof(arg) if legal return unsafe_to_llvm(B, val) end T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg=true) + FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) fn, _ = get_function!(mod, "jl_typeof", FT) call!(B, FT, fn, [arg]) end @@ -1206,43 +1488,65 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value meth = only(methods(func)) tag = emit_apply_type!(B, Tuple, primalvaltys) -# TT = meth.sig -# while TT isa UnionAll -# TT = TT.body -# end -# parms = TT.parameters -# -# tosv = primalvaltys -# if length(parms) > 0 && typeof(parms[end]) == Core.TypeofVararg -# tosv = LLVM.Value[tosv[1:length(parms)-1]..., emit_apply_type!(B, Tuple, tosv[length(parms):end])] -# end -# sv = emit_svec!(B, tosv[2:end]) -# + # TT = meth.sig + # while TT isa UnionAll + # TT = TT.body + # end + # parms = TT.parameters + # + # tosv = primalvaltys + # if length(parms) > 0 && typeof(parms[end]) == Core.TypeofVararg + # tosv = LLVM.Value[tosv[1:length(parms)-1]..., 
emit_apply_type!(B, Tuple, tosv[length(parms):end])] + # end + # sv = emit_svec!(B, tosv[2:end]) + # meth = unsafe_to_llvm(B, meth) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - worlds, FT = get_function!(mod, "jl_gf_invoke_lookup_worlds", - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT])) + worlds, FT = get_function!( + mod, + "jl_gf_invoke_lookup_worlds", + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT]), + ) EB = LLVM.IRBuilder() position!(EB, first(LLVM.instructions(LLVM.entry(fn)))) minworld = alloca!(EB, sizeT) maxworld = alloca!(EB, sizeT) store!(B, LLVM.ConstantInt(sizeT, 0), minworld) store!(B, LLVM.ConstantInt(sizeT, -1), maxworld) - methodmatch = call!(B, FT, worlds, LLVM.Value[tag, unsafe_to_llvm(B, nothing), LLVM.ConstantInt(sizeT, world), minworld, maxworld]) + methodmatch = call!( + B, + FT, + worlds, + LLVM.Value[ + tag, + unsafe_to_llvm(B, nothing), + LLVM.ConstantInt(sizeT, world), + minworld, + maxworld, + ], + ) # emit_jl!(B, methodmatch) # emit_jl!(B, emit_jltypeof!(B, methodmatch)) offset = 1 - AT = LLVM.ArrayType(T_prjlvalue, offset+1) + AT = LLVM.ArrayType(T_prjlvalue, offset + 1) methodmatch = addrspacecast!(B, methodmatch, LLVM.PointerType(T_jlvalue, Derived)) methodmatch = bitcast!(B, methodmatch, LLVM.PointerType(AT, Derived)) - gep = LLVM.inbounds_gep!(B, AT, methodmatch, LLVM.Value[LLVM.ConstantInt(0), LLVM.ConstantInt(offset)]) + gep = LLVM.inbounds_gep!( + B, + AT, + methodmatch, + LLVM.Value[LLVM.ConstantInt(0), LLVM.ConstantInt(offset)], + ) sv = LLVM.load!(B, T_prjlvalue, gep) - fn, FT = get_function!(mod, "jl_specializations_get_linfo", - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, T_prjlvalue])) + fn, FT = get_function!( + mod, + "jl_specializations_get_linfo", + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, T_prjlvalue]), + ) mi = call!(B, FT, fn, [meth, tag, sv]) @@ -1259,86 +1563,99 @@ 
end function get_array_struct() -@static if VERSION < v"1.11-" -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# void *data; -# #ifdef STORE_ARRAY_LEN (just true new newer versions) -# size_t length; -# #endif -# jl_array_flags_t flags; -# uint16_t elsize; // element size including alignment (dim 1 memory stride) -# uint32_t offset; // for 1-d only. does not need to get big. -# size_t nrows; -# union { -# // 1d -# size_t maxsize; -# // Nd -# size_t ncols; -# }; -# // other dim sizes go here for ndims > 2 -# -# // followed by alignment padding and inline data, or owner pointer -# } jl_array_t; - - i8 = LLVM.IntType(8) - ptrty = LLVM.PointerType(i8, 13) - sizeT = LLVM.IntType(8*sizeof(Csize_t)) - arrayFlags = LLVM.IntType(16) - elsz = LLVM.IntType(16) - off = LLVM.IntType(32) - nrows = LLVM.IntType(8*sizeof(Csize_t)) - - return LLVM.StructType([ptrty, sizeT, arrayFlags, elsz, off, nrows]; packed=true) -else -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# size_t length; -# void *ptr; -# // followed by padding and inline data, or owner pointer -# #ifdef _P64 -# // union { -# // jl_value_t *owner; -# // T inl[]; -# // }; -# #else -# // -# // jl_value_t *owner; -# // size_t padding[1]; -# // T inl[]; -# #endif -# } jl_genericmemory_t; -# -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# void *ptr_or_offset; -# jl_genericmemory_t *mem; -# } jl_genericmemoryref_t; -# -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# jl_genericmemoryref_t ref; -# size_t dimsize[]; // length for 1-D, otherwise length is mem->length -# } jl_array_t; - i8 = LLVM.IntType(8) - ptrty = LLVM.PointerType(i8, 10) - sizeT = LLVM.IntType(8*sizeof(Csize_t)) - return LLVM.StructType([ptrty, sizeT]; packed=true) -end + @static if VERSION < v"1.11-" + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *data; + # #ifdef STORE_ARRAY_LEN (just true new newer versions) + # size_t length; + # #endif + # jl_array_flags_t flags; + # uint16_t elsize; // element size including alignment (dim 1 
memory stride) + # uint32_t offset; // for 1-d only. does not need to get big. + # size_t nrows; + # union { + # // 1d + # size_t maxsize; + # // Nd + # size_t ncols; + # }; + # // other dim sizes go here for ndims > 2 + # + # // followed by alignment padding and inline data, or owner pointer + # } jl_array_t; + + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 13) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + arrayFlags = LLVM.IntType(16) + elsz = LLVM.IntType(16) + off = LLVM.IntType(32) + nrows = LLVM.IntType(8 * sizeof(Csize_t)) + + return LLVM.StructType([ptrty, sizeT, arrayFlags, elsz, off, nrows]; packed = true) + else + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # size_t length; + # void *ptr; + # // followed by padding and inline data, or owner pointer + # #ifdef _P64 + # // union { + # // jl_value_t *owner; + # // T inl[]; + # // }; + # #else + # // + # // jl_value_t *owner; + # // size_t padding[1]; + # // T inl[]; + # #endif + # } jl_genericmemory_t; + # + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *ptr_or_offset; + # jl_genericmemory_t *mem; + # } jl_genericmemoryref_t; + # + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # jl_genericmemoryref_t ref; + # size_t dimsize[]; // length for 1-D, otherwise length is mem->length + # } jl_array_t; + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 10) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + return LLVM.StructType([ptrty, sizeT]; packed = true) + end end function get_array_data(B, array) i8 = LLVM.IntType(8) ptrty = LLVM.PointerType(i8, 13) - array = LLVM.pointercast!(B, array, LLVM.PointerType(ptrty, LLVM.addrspace(LLVM.value_type(array)))) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ptrty, LLVM.addrspace(LLVM.value_type(array))), + ) return LLVM.load!(B, ptrty, array) end function get_array_elsz(B, array) ST = get_array_struct() elsz = LLVM.IntType(16) - array = LLVM.pointercast!(B, array, LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array)))) - 
v = inbounds_gep!(B, ST, array, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(3))]) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(3))], + ) return LLVM.load!(B, elsz, v) end @@ -1351,13 +1668,16 @@ function get_array_len(B, array) end for (fname, num) in ( - ("jl_alloc_array_1d", 1), ("ijl_alloc_array_1d", 1), - ("jl_alloc_array_2d", 2), ("jl_alloc_array_2d", 2), - ("jl_alloc_array_2d", 3), ("jl_alloc_array_2d", 3), - ) + ("jl_alloc_array_1d", 1), + ("ijl_alloc_array_1d", 1), + ("jl_alloc_array_2d", 2), + ("jl_alloc_array_2d", 2), + ("jl_alloc_array_2d", 3), + ("jl_alloc_array_2d", 3), + ) if nm == fname res = operands(array)[2] - for i in 2:num + for i = 2:num res = mul!(B, res, operands(array)[1+i]) end return res @@ -1365,17 +1685,35 @@ function get_array_len(B, array) end end ST = get_array_struct() - array = LLVM.pointercast!(B, array, LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array)))) - v = inbounds_gep!(B, ST, array, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))]) - sizeT = LLVM.IntType(8*sizeof(Csize_t)) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))], + ) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) return LLVM.load!(B, sizeT, v) end function get_array_nrows(B, array) ST = get_array_struct() - array = LLVM.pointercast!(B, array, LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array)))) - v = inbounds_gep!(B, ST, array, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))]) - nrows = LLVM.IntType(8*sizeof(Csize_t)) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + 
array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))], + ) + nrows = LLVM.IntType(8 * sizeof(Csize_t)) return LLVM.load!(B, nrows, v) end @@ -1391,7 +1729,14 @@ function permit_inlining!(f::LLVM.Function) if isa(inst, LLVM.LoadInst) md = metadata(inst) if haskey(md, LLVM.MD_tbaa) - modified = LLVM.Metadata(ccall((:EnzymeMakeNonConstTBAA, API.libEnzyme), LLVM.API.LLVMMetadataRef, (LLVM.API.LLVMMetadataRef,), md[LLVM.MD_tbaa])) + modified = LLVM.Metadata( + ccall( + (:EnzymeMakeNonConstTBAA, API.libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef,), + md[LLVM.MD_tbaa], + ), + ) setindex!(md, modified, LLVM.MD_tbaa) end if haskey(md, LLVM.MD_invariant_load) @@ -1406,11 +1751,15 @@ struct Tape{TapeTy,ShadowTy,ResT} shadow_return::ShadowTy end -function emit_gc_preserve_begin(B::LLVM.IRBuilder, args=LLVM.Value[]) +function emit_gc_preserve_begin(B::LLVM.IRBuilder, args = LLVM.Value[]) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - func, FT = get_function!(mod, "llvm.julia.gc_preserve_begin", LLVM.FunctionType(LLVM.TokenType(), vararg=true)) + func, FT = get_function!( + mod, + "llvm.julia.gc_preserve_begin", + LLVM.FunctionType(LLVM.TokenType(), vararg = true), + ) token = call!(B, FT, func, args) return token @@ -1421,7 +1770,11 @@ function emit_gc_preserve_end(B::LLVM.IRBuilder, token) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - func, FT = get_function!(mod, "llvm.julia.gc_preserve_end", LLVM.FunctionType(LLVM.VoidType(), [LLVM.TokenType()])) + func, FT = get_function!( + mod, + "llvm.julia.gc_preserve_end", + LLVM.FunctionType(LLVM.VoidType(), [LLVM.TokenType()]), + ) call!(B, FT, func, [token]) return @@ -1440,20 +1793,29 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) allocate_sret!(B, N) end -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT <: AbstractFloat} +@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} return Base.zero(x) end 
-@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT <: AbstractFloat} +@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} return Base.zero(x) end -@inline function EnzymeCore.make_zero(x::Array{FT, N})::Array{FT, N} where {FT <: AbstractFloat, N} +@inline function EnzymeCore.make_zero( + x::Array{FT,N}, +)::Array{FT,N} where {FT<:AbstractFloat,N} return Base.zero(x) end -@inline function EnzymeCore.make_zero(x::Array{Complex{FT}, N})::Array{Complex{FT}, N} where {FT <: AbstractFloat, N} +@inline function EnzymeCore.make_zero( + x::Array{Complex{FT},N}, +)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} return Base.zero(x) end -@inline function EnzymeCore.make_zero(::Type{Array{FT, N}}, seen::IdDict, prev::Array{FT, N}, ::Val{copy_if_inactive}=Val(false))::Array{FT, N} where {copy_if_inactive, FT<:AbstractFloat, N} +@inline function EnzymeCore.make_zero( + ::Type{Array{FT,N}}, + seen::IdDict, + prev::Array{FT,N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} if haskey(seen, prev) return seen[prev] end @@ -1461,7 +1823,12 @@ end seen[prev] = newa return newa end -@inline function EnzymeCore.make_zero(::Type{Array{Complex{FT}, N}}, seen::IdDict, prev::Array{Complex{FT}, N}, ::Val{copy_if_inactive}=Val(false))::Array{Complex{FT}, N} where {copy_if_inactive, FT<:AbstractFloat, N} +@inline function EnzymeCore.make_zero( + ::Type{Array{Complex{FT},N}}, + seen::IdDict, + prev::Array{Complex{FT},N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} if haskey(seen, prev) return seen[prev] end @@ -1470,15 +1837,30 @@ end return newa end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:AbstractFloat} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + 
::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:AbstractFloat} return RT(0) end -@inline function EnzymeCore.make_zero(::Type{Complex{RT}}, seen::IdDict, prev::Complex{RT}, ::Val{copy_if_inactive}=Val(false))::Complex{RT} where {copy_if_inactive, RT<:AbstractFloat} +@inline function EnzymeCore.make_zero( + ::Type{Complex{RT}}, + seen::IdDict, + prev::Complex{RT}, + ::Val{copy_if_inactive} = Val(false), +)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} return RT(0) end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Array} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Array} if haskey(seen, prev) return seen[prev] end @@ -1491,35 +1873,58 @@ end if isassigned(prev, I) pv = prev[I] innerty = Core.Typeof(pv) - @inbounds newa[I] = EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) + @inbounds newa[I] = + EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) end end return newa end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Tuple} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Tuple} return ntuple(length(prev)) do i Base.@_inline_meta EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) end end -@inline function EnzymeCore.make_zero(::Type{NamedTuple{A,RT}}, seen::IdDict, prev::NamedTuple{A,RT}, ::Val{copy_if_inactive}=Val(false))::NamedTuple{A,RT} where {copy_if_inactive, A,RT} +@inline function EnzymeCore.make_zero( + ::Type{NamedTuple{A,RT}}, + seen::IdDict, + prev::NamedTuple{A,RT}, + ::Val{copy_if_inactive} = Val(false), +)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} return 
NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) end -@inline function EnzymeCore.make_zero(::Type{Core.Box}, seen::IdDict, prev::Core.Box, ::Val{copy_if_inactive}=Val(false)) where {copy_if_inactive} +@inline function EnzymeCore.make_zero( + ::Type{Core.Box}, + seen::IdDict, + prev::Core.Box, + ::Val{copy_if_inactive} = Val(false), +) where {copy_if_inactive} if haskey(seen, prev) return seen[prev] end prev2 = prev.contents res = Core.Box() seen[prev] = res - res.contents = Base.Ref(EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive))) + res.contents = Base.Ref( + EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), + ) return res end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT} if guaranteed_const_nongen(RT, nothing) return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev end @@ -1529,11 +1934,11 @@ end @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) - + if ismutable(prev) y = ccall(:jl_new_struct_uninit, Any, (Any,), RT) seen[prev] = y - for i in 1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) T = Core.Typeof(xi) @@ -1543,13 +1948,13 @@ end end return y end - + if nf == 0 return prev end flds = Vector{Any}(undef, nf) - for i in 1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) @@ -1564,32 +1969,33 @@ end return y end -function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S} +function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} zero(T) end -function make_zero_immutable!(prev::Complex{T}, seen::S)::Complex{T} where {T <: AbstractFloat, S} +function make_zero_immutable!( + prev::Complex{T}, + seen::S, +)::Complex{T} where {T<:AbstractFloat,S} zero(T) end -function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S} +function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} ntuple(Val(length(T.parameters))) do i Base.@_inline_meta make_zero_immutable!(prev[i], seen) end end -function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S} - NamedTuple{a, b}( - ntuple(Val(length(T.parameters))) do i +function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} + NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i Base.@_inline_meta make_zero_immutable!(prev[a[i]], seen) - end - ) + end) end -function make_zero_immutable!(prev::T, seen::S)::T where {T, S} +function make_zero_immutable!(prev::T, seen::S)::T where {T,S} if guaranteed_const_nongen(T, nothing) return prev end @@ -1601,11 +2007,11 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} nf = fieldcount(RT) flds = Vector{Any}(undef, nf) - for i in 
1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, #=justActive=#Val(true)) == ActiveState + flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# make_zero_immutable!(xi, seen) else EnzymeCore.make_zero!(xi, seen) @@ -1619,47 +2025,65 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T <: AbstractFloat, ST} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} T[] = zero(T) nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}}, seen::ST)::Nothing where {T <: AbstractFloat, ST} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{Complex{T}}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} T[] = zero(Complex{T}) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} +@inline function EnzymeCore.make_zero!( + prev::Array{T,N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} fill!(prev, zero(T)) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} fill!(prev, zero(Complex{T})) nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T})::Nothing where {T <: AbstractFloat} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, +)::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}})::Nothing where {T <: AbstractFloat} +@inline function EnzymeCore.make_zero!( + 
prev::Base.RefValue{Complex{T}}, +)::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{T, N})::Nothing where {T <: AbstractFloat, N} +@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N})::Nothing where {T <: AbstractFloat, N} +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, +)::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T, N, ST} +@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} if guaranteed_const_nongen(T, nothing) return end @@ -1672,7 +2096,7 @@ end if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# @inbounds prev[I] = make_zero_immutable!(pv, seen) nothing else @@ -1684,7 +2108,10 @@ end nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T, ST} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T,ST} if guaranteed_const_nongen(T, nothing) return end @@ -1695,7 +2122,7 @@ end pv = prev[] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# prev[] = make_zero_immutable!(pv, seen) nothing else @@ -1716,7 +2143,7 @@ end end push!(seen, prev) SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# prev.contents = 
EnzymeCore.make_zero_immutable!(pv, seen) nothing else @@ -1726,7 +2153,10 @@ end nothing end -@inline function EnzymeCore.make_zero!(prev::T, seen::S=Base.IdSet{Any}())::Nothing where {T, S} +@inline function EnzymeCore.make_zero!( + prev::T, + seen::S = Base.IdSet{Any}(), +)::Nothing where {T,S} if guaranteed_const_nongen(T, nothing) return end @@ -1736,7 +2166,7 @@ end @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - + if nf == 0 return @@ -1744,14 +2174,14 @@ end push!(seen, prev) - for i in 1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) SBT = Core.Typeof(xi) if guaranteed_const_nongen(SBT, nothing) continue end - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# setfield!(prev, i, make_zero_immutable!(xi, seen)) nothing else @@ -1763,7 +2193,7 @@ end return end -function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeException) +function emit_error(B::LLVM.IRBuilder, orig, string, errty = EnzymeRuntimeException) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -1777,17 +2207,22 @@ function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeExceptio vt = LLVM.VoidType() ptr = convert(LLVMType, Ptr{Cvoid}) - exc, _ = get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) + exc, _ = + get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) string = ptrtoint!(B, string, ptr) call!(B, LLVM.function_type(exc), exc, [string]) - framefn, ft = get_function!(mod, "gpu_report_exception_frame", LLVM.FunctionType(vt, [LLVM.Int32Type(), ptr, ptr, LLVM.Int32Type()])) + framefn, ft = get_function!( + mod, + "gpu_report_exception_frame", + LLVM.FunctionType(vt, [LLVM.Int32Type(), ptr, ptr, LLVM.Int32Type()]), + ) if orig !== nothing bt = GPUCompiler.backtrace(orig) - for (i,frame) in enumerate(bt) + for (i, frame) in 
enumerate(bt) idx = ConstantInt(parameters(ft)[1], i) func = globalstring_ptr!(B, String(frame.func), "di_func") func = ptrtoint!(B, func, ptr) @@ -1797,27 +2232,42 @@ function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeExceptio call!(B, ft, framefn, [idx, func, file, line]) end end - - sigfn, sigft = get_function!(mod, "gpu_signal_exception", LLVM.FunctionType(vt, LLVM.LLVMType[])) - call!(B, sigft, sigfn) - trap_ft = LLVM.FunctionType(LLVM.VoidType()) - trap = if haskey(functions(mod), "llvm.trap") - functions(mod)["llvm.trap"] - else - LLVM.Function(mod, "llvm.trap", trap_ft) - end - call!(B, trap_ft, trap) + + sigfn, sigft = get_function!( + mod, + "gpu_signal_exception", + LLVM.FunctionType(vt, LLVM.LLVMType[]), + ) + call!(B, sigft, sigfn) + trap_ft = LLVM.FunctionType(LLVM.VoidType()) + trap = if haskey(functions(mod), "llvm.trap") + functions(mod)["llvm.trap"] + else + LLVM.Function(mod, "llvm.trap", trap_ft) + end + call!(B, trap_ft, trap) else err = emit_allocobj!(B, errty) err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) store!(B, string, err2) - emit_jl_throw!(B, addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12))) + emit_jl_throw!( + B, + addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12)), + ) end # 2. 
Call error function and insert unreachable - LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("noreturn")) + LLVM.API.LLVMAddCallSiteAttribute( + ct, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + EnumAttribute("noreturn"), + ) if EnzymeMutabilityException != errty - LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_error")) + LLVM.API.LLVMAddCallSiteAttribute( + ct, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + StringAttribute("enzyme_error"), + ) end return ct end @@ -1839,15 +2289,21 @@ function prepare_llvm(mod, job, meta) continue end llvmfn = functions(mod)[k_name] - + RT = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype _, _, returnRoots = get_return_info(RT) returnRoots = returnRoots !== nothing attributes = function_attributes(llvmfn) - push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))))) - push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(RT))))) + push!( + attributes, + StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))), + ) + push!( + attributes, + StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(RT)))), + ) if returnRoots attr = StringAttribute("enzymejl_returnRoots", "") push!(parameter_attributes(llvmfn, 2), attr) @@ -1860,10 +2316,15 @@ function prepare_llvm(mod, job, meta) end end -function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world) +function nested_codegen!( + mode::API.CDerivativeMode, + mod::LLVM.Module, + funcspec::Core.MethodInstance, + world, +) # TODO: Put a cache here index on `mod` and f->tt - + # 3) Use the MI to create the correct augmented fwd/reverse # TODO: # - GPU support @@ -1871,16 +2332,23 @@ 
function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, funcspec:: target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) - job = CompilerJob(funcspec, CompilerConfig(target, params; kernel=false), world) + job = CompilerJob(funcspec, CompilerConfig(target, params; kernel = false), world) # TODO parent_job = nothing - otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, cleanup=false, validate=false, parent_job=parent_job) + otherMod, meta = GPUCompiler.codegen( + :llvm, + job; + optimize = false, + cleanup = false, + validate = false, + parent_job = parent_job, + ) prepare_llvm(otherMod, job, meta) entry = name(meta.entry) - + for f in functions(otherMod) permit_inlining!(f) end @@ -1907,7 +2375,7 @@ function removed_ret_parms(F::LLVM.Function) parmrem = nothing retRemove = false for a in collect(function_attributes(F)) - if isa(a, StringAttribute) + if isa(a, StringAttribute) if kind(a) == "enzyme_parmremove" parmrem = a end @@ -1928,8 +2396,8 @@ end abstract type CompilationException <: Base.Exception end struct NoDerivativeException <: CompilationException msg::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::NoDerivativeException) @@ -1948,8 +2416,8 @@ end struct IllegalTypeAnalysisException <: CompilationException msg::String sval::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) @@ -1962,7 +2430,7 @@ function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) write(io, ece.sval) print(io, '\n', ece.msg, '\n') if ece.bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, ece.bt) println(io) end @@ -1970,8 +2438,8 @@ end struct 
IllegalFirstPointerException <: CompilationException msg::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::IllegalFirstPointerException) @@ -1989,8 +2457,8 @@ end struct EnzymeInternalError <: CompilationException msg::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::EnzymeInternalError) @@ -2006,63 +2474,76 @@ function Base.showerror(io::IO, ece::EnzymeInternalError) end end -parent_scope(val::LLVM.Function, depth=0) = depth==0 ? LLVM.parent(val) : val -parent_scope(val::LLVM.Module, depth=0) = val -parent_scope(val::LLVM.Value, depth=0) = parent_scope(LLVM.parent(val), depth+1) -parent_scope(val::LLVM.Argument, depth=0) = parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth+1) +parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? 
LLVM.parent(val) : val +parent_scope(val::LLVM.Module, depth = 0) = val +parent_scope(val::LLVM.Value, depth = 0) = parent_scope(LLVM.parent(val), depth + 1) +parent_scope(val::LLVM.Argument, depth = 0) = + parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth + 1) -const CheckNan = Ref(false) -function julia_sanitize(orig::LLVM.API.LLVMValueRef, val::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef, mask::LLVM.API.LLVMValueRef)::LLVM.API.LLVMValueRef - orig = LLVM.Value(orig) - val = LLVM.Value(val) - B = LLVM.IRBuilder(B) - if CheckNan[] - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - ty = LLVM.value_type(val) - vt = LLVM.VoidType() - FT = LLVM.FunctionType(vt, [ty, LLVM.PointerType(LLVM.Int8Type())]) +const CheckNan = Ref(false) +function julia_sanitize( + orig::LLVM.API.LLVMValueRef, + val::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, + mask::LLVM.API.LLVMValueRef, +)::LLVM.API.LLVMValueRef + orig = LLVM.Value(orig) + val = LLVM.Value(val) + B = LLVM.IRBuilder(B) + if CheckNan[] + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + ty = LLVM.value_type(val) + vt = LLVM.VoidType() + FT = LLVM.FunctionType(vt, [ty, LLVM.PointerType(LLVM.Int8Type())]) - stringv = "Enzyme: Found nan while computing derivative of "*string(orig) - if orig !== nothing && isa(orig, LLVM.Instruction) - bt = GPUCompiler.backtrace(orig) - function printBT(io) - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) + stringv = "Enzyme: Found nan while computing derivative of " * string(orig) + if orig !== nothing && isa(orig, LLVM.Instruction) + bt = GPUCompiler.backtrace(orig) + function printBT(io) + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end + stringv *= sprint(io -> Base.show_backtrace(io, bt)) end - stringv*=sprint(io->Base.show_backtrace(io, bt)) - end - fn, _ = get_function!(mod, "julia.sanitize."*string(ty), FT) - if isempty(blocks(fn)) - let builder = IRBuilder() - entry = 
BasicBlock(fn, "entry") - good = BasicBlock(fn, "good") - bad = BasicBlock(fn, "bad") - position!(builder, entry) - inp, sval = collect(parameters(fn)) - cmp = fcmp!(builder, LLVM.API.LLVMRealUNO, inp, inp) + fn, _ = get_function!(mod, "julia.sanitize." * string(ty), FT) + if isempty(blocks(fn)) + let builder = IRBuilder() + entry = BasicBlock(fn, "entry") + good = BasicBlock(fn, "good") + bad = BasicBlock(fn, "bad") + position!(builder, entry) + inp, sval = collect(parameters(fn)) + cmp = fcmp!(builder, LLVM.API.LLVMRealUNO, inp, inp) - br!(builder, cmp, bad, good) + br!(builder, cmp, bad, good) - position!(builder, good) - ret!(builder) + position!(builder, good) + ret!(builder) - position!(builder, bad) + position!(builder, bad) - emit_error(builder, nothing, sval, EnzymeNoDerivativeError) - unreachable!(builder) - dispose(builder) + emit_error(builder, nothing, sval, EnzymeNoDerivativeError) + unreachable!(builder) + dispose(builder) + end end + # val = + call!(B, FT, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) end - # val = - call!(B, FT, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) - end - return val.ref + return val.ref end -function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.ErrorType, data::Ptr{Cvoid}, data2::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef)::LLVM.API.LLVMValueRef +function julia_error( + cstr::Cstring, + val::LLVM.API.LLVMValueRef, + errtype::API.ErrorType, + data::Ptr{Cvoid}, + data2::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, +)::LLVM.API.LLVMValueRef msg = Base.unsafe_string(cstr) bt = nothing ir = nothing @@ -2098,12 +2579,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end if errtype == API.ET_NoDerivative - if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward pass", msg) || occursin("No reverse pass found", msg) + if occursin("No create nofree of empty 
function", msg) || + occursin("No forward mode derivative found for", msg) || + occursin("No augmented forward pass", msg) || + occursin("No reverse pass found", msg) ir = nothing end if B != C_NULL B = IRBuilder(B) - msg2 = sprint() do io + msg2 = sprint() do io if ir !== nothing print(io, "Current scope: \n") print(io, ir) @@ -2124,7 +2608,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err msgN = sprint() do io::IO if isa(val, LLVM.Argument) fn = parent_scope(val) - ir = string(LLVM.name(fn))*string(function_type(fn)) + ir = string(LLVM.name(fn)) * string(function_type(fn)) print(io, "Current scope: \n") print(io, ir) end @@ -2137,7 +2621,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end print(io, '\n', msg, '\n') if bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, bt) println(io) end @@ -2149,9 +2633,12 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err ip = API.EnzymeTypeAnalyzerToString(data) sval = Base.unsafe_string(ip) API.EnzymeStringFree(ip) - + if isa(val, LLVM.Instruction) - mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false) + mi, rt = enzyme_custom_extract_mi( + LLVM.parent(LLVM.parent(val))::LLVM.Function, + false, + ) #=error=# if mi !== nothing msg *= "\n" * string(mi) * "\n" end @@ -2160,7 +2647,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err elseif errtype == API.ET_NoType @assert B != C_NULL B = IRBuilder(B) - + data = API.EnzymeTypeAnalyzerRef(data) ip = API.EnzymeTypeAnalyzerToString(data) sval = Base.unsafe_string(ip) @@ -2177,12 +2664,12 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end print(io, '\n', msg, '\n') if bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, bt) println(io) end pscope = parent_scope(val) - mi, rt = 
enzyme_custom_extract_mi(pscope, #=error=#false) + mi, rt = enzyme_custom_extract_mi(pscope, false) #=error=# if mi !== nothing println(io, "within ", mi) end @@ -2227,18 +2714,20 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err badval = nothing gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) # Ignore mismatched activity if phi/store of ghost - seen = Dict{LLVM.Value, LLVM.Value}() + seen = Dict{LLVM.Value,LLVM.Value}() illegal = false - created = LLVM.Instruction[] + created = LLVM.Instruction[] world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) - width = get_width(gutils) + width = get_width(gutils) function make_batched(cur, B) if width == 1 return cur else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, cur, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, cur, idx - 1) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end @@ -2254,8 +2743,8 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err if cur in keys(seen) return seen[cur] end - - legal, TT = abs_typeof(cur, true) + + legal, TT, byref = abs_typeof(cur, true) if legal if guaranteed_const_nongen(TT, world) return make_batched(ncur, prevbb) @@ -2264,16 +2753,21 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err legal2, obj = absint(cur) # Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple - if legal2 && active_reg_inner(TT, (), world) == ActiveState && isa(cur, LLVM.ConstantExpr) && cur == data2 + if legal2 && + active_reg_inner(TT, (), world) == ActiveState && + isa(cur, LLVM.ConstantExpr) && + cur == data2 if width == 1 res = emit_allocobj!(prevbb, Base.RefValue{TT}) push!(created, res) return res else - shadowres = 
UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur)))) - for idx in 1:width + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), + ) + for idx = 1:width res = emit_allocobj!(prevbb, Base.RefValue{TT}) - shadowres = insert_value!(prevbb, shadowres, res, idx-1) + shadowres = insert_value!(prevbb, shadowres, res, idx - 1) push!(created, shadowres) end return shadowres @@ -2281,15 +2775,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end badval = if legal2 - string(obj)*" of type"*" "*string(TT) + string(obj) * " of type" * " " * string(TT) else - "Unknown object of type"*" "*string(TT) + "Unknown object of type" * " " * string(TT) end illegalVal = cur illegal = true return make_batched(ncur, prevbb) end - + if isa(cur, LLVM.PointerNull) return make_batched(ncur, prevbb) end @@ -2297,9 +2791,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return make_batched(ncur, prevbb) end @static if LLVM.version() >= v"12" - if isa(cur, LLVM.PoisonValue) - return make_batched(ncur, prevbb) - end + if isa(cur, LLVM.PoisonValue) + return make_batched(ncur, prevbb) + end end if isa(cur, LLVM.ConstantAggregateZero) return make_batched(ncur, prevbb) @@ -2313,10 +2807,10 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end end if isa(cur, LLVM.ConstantFP) - return make_batched(ConstantFP(value_type(cur), 0), prevbb) + return make_batched(ConstantFP(value_type(cur), 0), prevbb) end if isa(cur, LLVM.ConstantDataSequential) - cvals = LLVM.Value[] + cvals = LLVM.Value[] changed = false for v in collect(cur) tmp = make_replacement(v, prevbb) @@ -2340,20 +2834,23 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return cur2 end if isa(cur, LLVM.ConstantInt) - if LLVM.width(value_type(cur)) <= sizeof(Int)*8 + if LLVM.width(value_type(cur)) <= sizeof(Int) * 8 return make_batched(ncur, prevbb) 
end - if LLVM.width(value_type(cur)) == sizeof(Int)*8 && abs(convert(Int, cur)) < 10000 + if LLVM.width(value_type(cur)) == sizeof(Int) * 8 && + abs(convert(Int, cur)) < 10000 return make_batched(ncur, prevbb) end # if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe # for activity state to mix - if isa(val, LLVM.StoreInst) operands(val)[1] == cur && !isa(value_type(operands(val)[1]), LLVM.PointerType) + if isa(val, LLVM.StoreInst) + operands(val)[1] == cur && + !isa(value_type(operands(val)[1]), LLVM.PointerType) return make_batched(ncur, prevbb) end end - - if isa(cur, LLVM.SelectInst) + + if isa(cur, LLVM.SelectInst) lhs = make_replacement(operands(cur)[2], prevbb) if illegal return ncur @@ -2366,22 +2863,37 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return make_batched(ncur, prevbb) end if width == 1 - nv = select!(prevbb, new_from_original(gutils, operands(cur)[1]), lhs, rhs) + nv = select!( + prevbb, + new_from_original(gutils, operands(cur)[1]), + lhs, + rhs, + ) push!(created, nv) seen[cur] = nv return nv else shadowres = LLVM.UndefValue(value_type(lhs)) - for idx in 1:width - shadowres = insert_value!(prevbb, shadowres, select!(prevbb, new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx-1), extract_value!(prevbb, rhs, idx-1)), idx-1) + for idx = 1:width + shadowres = insert_value!( + prevbb, + shadowres, + select!( + prevbb, + new_from_original(gutils, operands(cur)[1]), + extract_value!(prevbb, lhs, idx - 1), + extract_value!(prevbb, rhs, idx - 1), + ), + idx - 1, + ) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end end return shadowres end - end - + end + if isa(cur, LLVM.InsertValueInst) lhs = make_replacement(operands(cur)[1], prevbb) if illegal @@ -2396,7 +2908,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end inds = LLVM.API.LLVMGetIndices(cur.ref) ninds = LLVM.API.LLVMGetNumIndices(cur.ref) - 
jinds = Cuint[unsafe_load(inds, i) for i in 1:ninds] + jinds = Cuint[unsafe_load(inds, i) for i = 1:ninds] if width == 1 nv = API.EnzymeInsertValue(prevbb, lhs, rhs, jinds) push!(created, nv) @@ -2404,10 +2916,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return nv else shadowres = lhs - for idx in 1:width + for idx = 1:width jindsv = copy(jinds) - pushfirst!(jindsv, idx-1) - shadowres = API.EnzymeInsertValue(prevbb, shadowres, extract_value!(prevbb, rhs, idx-1), jindsv) + pushfirst!(jindsv, idx - 1) + shadowres = API.EnzymeInsertValue( + prevbb, + shadowres, + extract_value!(prevbb, rhs, idx - 1), + jindsv, + ) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end @@ -2415,15 +2932,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return shadowres end end - + if isa(cur, LLVM.PHIInst) Bphi = IRBuilder() position!(Bphi, ncur) shadowty = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))) - phi2 = phi!(Bphi, shadowty, "tempphi"*LLVM.name(cur)) + phi2 = phi!(Bphi, shadowty, "tempphi" * LLVM.name(cur)) seen[cur] = phi2 changed = false - recsize = length(created)+1 + recsize = length(created) + 1 for (v, bb) in LLVM.incoming(cur) B2 = IRBuilder() position!(B2, new_from_original(gutils, last(instructions(bb)))) @@ -2442,15 +2959,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err LLVM.API.LLVMInstructionEraseFromParent(phi2) seen[cur] = ncur plen = length(created) - for i in recsize:plen + for i = recsize:plen u = created[i] replace_uses!(u, LLVM.UndefValue(value_type(u))) end - for i in recsize:plen + for i = recsize:plen u = created[i] LLVM.API.LLVMInstructionEraseFromParent(u) end - for i in recsize:plen + for i = recsize:plen pop!(created) end return illegal ? 
ncur : make_batched(ncur, prevbb) @@ -2477,7 +2994,10 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err LLVM.API.LLVMInstructionEraseFromParent(u) end if LLVM.API.LLVMIsAReturnInst(val) != C_NULL - mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false) + mi, rt = enzyme_custom_extract_mi( + LLVM.parent(LLVM.parent(val))::LLVM.Function, + false, + ) #=error=# if mi !== nothing && isghostty(rt) return C_NULL end @@ -2486,7 +3006,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err print(io, msg) println(io) if badval !== nothing - println(io, " value="*badval) + println(io, " value=" * badval) else ttval = val if isa(ttval, LLVM.StoreInst) @@ -2499,7 +3019,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err API.EnzymeStringFree(st) end if illegalVal !== nothing - println(io, " llvalue="*string(illegalVal)) + println(io, " llvalue=" * string(illegalVal)) end if bt !== nothing Base.show_backtrace(io, bt) @@ -2512,9 +3032,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err B = IRBuilder(B) msg5 = sprint() do io::IO print(io, "Enzyme internal error\n") - print(io, msg, '\n') + print(io, msg, '\n') if bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, bt) println(io) end @@ -2535,7 +3055,7 @@ function any_jltypes(Type::LLVM.PointerType) end any_jltypes(Type::LLVM.StructType) = any(any_jltypes, LLVM.elements(Type)) -any_jltypes(Type::Union{LLVM.VectorType, LLVM.ArrayType}) = any_jltypes(eltype(Type)) +any_jltypes(Type::Union{LLVM.VectorType,LLVM.ArrayType}) = any_jltypes(eltype(Type)) any_jltypes(::LLVM.IntegerType) = false any_jltypes(::LLVM.FloatingPointType) = false any_jltypes(::LLVM.VoidType) = false @@ -2543,12 +3063,13 @@ any_jltypes(::LLVM.VoidType) = false @inline any_jltypes(::Type{Nothing}) = false @inline any_jltypes(::Type{T}) where 
{T<:AbstractFloat} = false @inline any_jltypes(::Type{T}) where {T<:Integer} = false -@inline any_jltypes(::Type{Complex{T}}) where T = any_jltypes(T) +@inline any_jltypes(::Type{Complex{T}}) where {T} = any_jltypes(T) @inline any_jltypes(::Type{Tuple{}}) = false -@inline any_jltypes(::Type{NTuple{Size, T}}) where {Size, T} = any_jltypes(T) -@inline any_jltypes(::Type{Core.LLVMPtr{T, Addr}}) where {T, Addr} = 10 <= Addr <= 12 +@inline any_jltypes(::Type{NTuple{Size,T}}) where {Size,T} = any_jltypes(T) +@inline any_jltypes(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = 10 <= Addr <= 12 @inline any_jltypes(::Type{Any}) = true -@inline any_jltypes(::Type{NamedTuple{A,B}}) where {A,B} = any(any_jltypes(b) for b in B.parameters) +@inline any_jltypes(::Type{NamedTuple{A,B}}) where {A,B} = + any(any_jltypes(b) for b in B.parameters) @inline any_jltypes(::Type{T}) where {T<:Tuple} = any(any_jltypes(b) for b in T.parameters) nfields(Type::LLVM.StructType) = length(LLVM.elements(Type)) @@ -2559,9 +3080,9 @@ nfields(Type::LLVM.PointerType) = 1 mutable struct EnzymeTapeToLoad{T} data::T end -Base.eltype(::EnzymeTapeToLoad{T}) where T = T +Base.eltype(::EnzymeTapeToLoad{T}) where {T} = T -const TapeTypes = Dict{String, DataType}() +const TapeTypes = Dict{String,DataType}() base_type(T::UnionAll) = base_type(T.body) base_type(T::DataType) = T @@ -2570,8 +3091,10 @@ const WideIntWidths = [256, 512, 1024, 2048] let for n ∈ WideIntWidths - let T = Symbol(:UInt,n) - eval(quote primitive type $T <: Unsigned $n end end) + let T = Symbol(:UInt, n) + eval(quote + primitive type $T <: Unsigned $n end + end) end end end @@ -2583,8 +3106,8 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} nelems = LLVM.API.LLVMCountStructElementTypes(Type) containsAny = false syms = Symbol[] - for i in 1:nelems - e = LLVM.API.LLVMStructGetTypeAtIndex(Type, i-1) + for i = 1:nelems + e = LLVM.API.LLVMStructGetTypeAtIndex(Type, i - 1) T, sub = to_tape_type(e) containsAny |= sub 
push!(tys, T) @@ -2593,7 +3116,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} Tup = Tuple{tys...} if containsAny res = (syms...,) - return NamedTuple{res, Tup}, false + return NamedTuple{res,Tup}, false else return Tup, false end @@ -2606,9 +3129,9 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} e = LLVM.API.LLVMGetElementType(Type) tkind2 = LLVM.API.LLVMGetTypeKind(e) if tkind2 == LLVM.API.LLVMFunctionTypeKind - return Core.LLVMPtr{Cvoid, Int(addrspace)}, false + return Core.LLVMPtr{Cvoid,Int(addrspace)}, false else - return Core.LLVMPtr{to_tape_type(e)[1], Int(addrspace)}, false + return Core.LLVMPtr{to_tape_type(e)[1],Int(addrspace)}, false end end end @@ -2616,9 +3139,9 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} e = LLVM.API.LLVMGetElementType(Type) T, sub = to_tape_type(e) len = Int(LLVM.API.LLVMGetArrayLength(Type)) - Tup = NTuple{len, T} + Tup = NTuple{len,T} if sub - return NamedTuple{ntuple(Core.Symbol, Val(len)), Tup}, false + return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false else return Tup, false end @@ -2627,9 +3150,9 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} e = LLVM.API.LLVMGetElementType(Type) T, sub = to_tape_type(e) len = Int(LLVM.API.LLVMGetVectorSize(Type)) - Tup = NTuple{len, T} + Tup = NTuple{len,T} if sub - return NamedTuple{ntuple(Core.Symbol, Val(len)), Tup}, false + return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false else return Tup, false end @@ -2637,7 +3160,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} if tkind == LLVM.API.LLVMIntegerTypeKind N = LLVM.API.LLVMGetIntTypeWidth(Type) if N == 1 - return Bool, false + return Bool, false elseif N == 8 return UInt8, false elseif N == 16 @@ -2683,10 +3206,12 @@ function tape_type(LLVMType::LLVM.LLVMType) return TT end -from_tape_type(::Type{T}) where T<:AbstractFloat = convert(LLVMType, T) -from_tape_type(::Type{T}) where 
T<:Integer = convert(LLVMType, T) -from_tape_type(::Type{NTuple{Size, T}}) where {Size, T} = LLVM.ArrayType(from_tape_type(T), Size) -from_tape_type(::Type{Core.LLVMPtr{T, Addr}}) where {T, Addr} = LLVM.PointerType(from_tape_type(UInt8), Addr) +from_tape_type(::Type{T}) where {T<:AbstractFloat} = convert(LLVMType, T) +from_tape_type(::Type{T}) where {T<:Integer} = convert(LLVMType, T) +from_tape_type(::Type{NTuple{Size,T}}) where {Size,T} = + LLVM.ArrayType(from_tape_type(T), Size) +from_tape_type(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = + LLVM.PointerType(from_tape_type(UInt8), Addr) # from_tape_type(::Type{Core.LLVMPtr{T, Addr}}, ctx) where {T, Addr} = LLVM.PointerType(from_tape_type(T, ctx), Addr) from_tape_type(::Type{Any}) = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), Tracked) function from_tape_type(::Type{NamedTuple{A,B}}) where {A,B} @@ -2702,10 +3227,12 @@ function from_tape_type(::Type{B}) where {B<:Tuple} end # See get_current_task_from_pgcstack (used from 1.7+) -current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) +current_task_offset() = + -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) # See get_current_ptls_from_task (used from 1.7+) -current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) +current_ptls_offset() = + unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) function store_nonjl_types!(B, startval, p) T_jlvalue = LLVM.StructType(LLVMType[]) @@ -2714,7 +3241,7 @@ function store_nonjl_types!(B, startval, p) if p != nothing push!(vals, p) end - todo = Tuple{Tuple, LLVM.Value}[((), startval)] + todo = Tuple{Tuple,LLVM.Value}[((), startval)] while length(todo) != 0 path, cur = popfirst!(todo) ty = value_type(cur) @@ -2725,9 +3252,9 @@ function store_nonjl_types!(B, startval, p) end if isa(ty, LLVM.ArrayType) if any_jltypes(ty) - for i=1:length(ty) - ev = extract_value!(B, cur, i-1) - push!(todo, 
((path..., i-1), ev)) + for i = 1:length(ty) + ev = extract_value!(B, cur, i - 1) + push!(todo, ((path..., i - 1), ev)) end continue end @@ -2735,8 +3262,8 @@ function store_nonjl_types!(B, startval, p) if isa(ty, LLVM.StructType) if any_jltypes(ty) for (i, t) in enumerate(LLVM.elements(ty)) - ev = extract_value!(B, cur, i-1) - push!(todo, ((path..., i-1), ev)) + ev = extract_value!(B, cur, i - 1) + push!(todo, ((path..., i - 1), ev)) end continue end @@ -2751,7 +3278,7 @@ function store_nonjl_types!(B, startval, p) return end -function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[]) +function get_julia_inner_types(B, p, startvals...; added = LLVM.API.LLVMValueRef[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2765,7 +3292,12 @@ function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[] if isa(ty, LLVM.PointerType) if any_jltypes(ty) if addrspace(ty) != Tracked - cur = addrspacecast!(B, cur, LLVM.PointerType(eltype(ty), Tracked), LLVM.name(cur)*".innertracked") + cur = addrspacecast!( + B, + cur, + LLVM.PointerType(eltype(ty), Tracked), + LLVM.name(cur) * ".innertracked", + ) if isa(cur, LLVM.Instruction) push!(added, cur.ref) end @@ -2782,8 +3314,8 @@ function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[] end if isa(ty, LLVM.ArrayType) if any_jltypes(ty) - for i=1:length(ty) - ev = extract_value!(B, cur, i-1) + for i = 1:length(ty) + ev = extract_value!(B, cur, i - 1) if isa(ev, LLVM.Instruction) push!(added, ev.ref) end @@ -2795,7 +3327,7 @@ function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[] if isa(ty, LLVM.StructType) for (i, t) in enumerate(LLVM.elements(ty)) if any_jltypes(t) - ev = extract_value!(B, cur, i-1) + ev = extract_value!(B, cur, i - 1) if isa(ev, LLVM.Instruction) push!(added, ev.ref) end @@ -2822,14 +3354,20 @@ function get_julia_inner_types(B, p, startvals...; 
added=LLVM.API.LLVMValueRef[] return vals end -function julia_post_cache_store(SI::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef, R2)::Ptr{LLVM.API.LLVMValueRef} +function julia_post_cache_store( + SI::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, + R2, +)::Ptr{LLVM.API.LLVMValueRef} B = LLVM.IRBuilder(B) SI = LLVM.Instruction(SI) v = operands(SI)[1] p = operands(SI)[2] added = LLVM.API.LLVMValueRef[] while true - if isa(p, LLVM.GetElementPtrInst) || isa(p, LLVM.BitCastInst) || isa(p, LLVM.AddrSpaceCastInst) + if isa(p, LLVM.GetElementPtrInst) || + isa(p, LLVM.BitCastInst) || + isa(p, LLVM.AddrSpaceCastInst) p = operands(p)[1] continue end @@ -2845,14 +3383,17 @@ function julia_post_cache_store(SI::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuild end p = pn - vals = get_julia_inner_types(B, p, v, added=added) + vals = get_julia_inner_types(B, p, v, added = added) r = emit_writebarrier!(B, vals) @assert isa(r, LLVM.Instruction) push!(added, r.ref) end if R2 != C_NULL unsafe_store!(R2, length(added)) - ptr = Base.unsafe_convert(Ptr{LLVM.API.LLVMValueRef}, Libc.malloc(sizeof(LLVM.API.LLVMValueRef)*length(added))) + ptr = Base.unsafe_convert( + Ptr{LLVM.API.LLVMValueRef}, + Libc.malloc(sizeof(LLVM.API.LLVMValueRef) * length(added)), + ) for (i, v) in enumerate(added) @assert isa(LLVM.Value(v), LLVM.Instruction) unsafe_store!(ptr, v, i) @@ -2868,7 +3409,11 @@ function julia_default_tape_type(C::LLVM.API.LLVMContextRef) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) return T_prjlvalue.ref end -function julia_undef_value_for_type(mod::LLVM.API.LLVMModuleRef, Ty::LLVM.API.LLVMTypeRef, forceZero::UInt8)::LLVM.API.LLVMValueRef +function julia_undef_value_for_type( + mod::LLVM.API.LLVMModuleRef, + Ty::LLVM.API.LLVMTypeRef, + forceZero::UInt8, +)::LLVM.API.LLVMValueRef ty = LLVM.LLVMType(Ty) if !any_jltypes(ty) if forceZero != 0 @@ -2889,7 +3434,7 @@ function julia_undef_value_for_type(mod::LLVM.API.LLVMModuleRef, Ty::LLVM.API.LL end if isa(ty, LLVM.ArrayType) st = 
LLVM.Value(julia_undef_value_for_type(mod, eltype(ty).ref, forceZero)) - return ConstantArray(eltype(ty), [st for i in 1:length(ty)]).ref + return ConstantArray(eltype(ty), [st for i = 1:length(ty)]).ref end if isa(ty, LLVM.StructType) vals = LLVM.Constant[] @@ -2905,11 +3450,13 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie V = LLVM.CallInst(V) gutils = GradientUtils(gutils) mode = get_mode(gutils) - if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient || mode == API.DEM_ReverseModeCombined + if mode == API.DEM_ReverseModePrimal || + mode == API.DEM_ReverseModeGradient || + mode == API.DEM_ReverseModeCombined fn = LLVM.parent(LLVM.parent(V)) world = enzyme_extract_world(fn) - has, Ty = abs_typeof(V) - @assert has + has, Ty, byref = abs_typeof(V) + @assert has rt = active_reg_inner(Ty, (), world) if rt == ActiveState || rt == MixedState B = LLVM.IRBuilder() @@ -2920,7 +3467,14 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie nothing end -function julia_allocator(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMTypeRef, Count::LLVM.API.LLVMValueRef, AlignedSize::LLVM.API.LLVMValueRef, IsDefault::UInt8, ZI) +function julia_allocator( + B::LLVM.API.LLVMBuilderRef, + LLVMType::LLVM.API.LLVMTypeRef, + Count::LLVM.API.LLVMValueRef, + AlignedSize::LLVM.API.LLVMValueRef, + IsDefault::UInt8, + ZI, +) B = LLVM.IRBuilder(B) Count = LLVM.Value(Count) AlignedSize = LLVM.Value(AlignedSize) @@ -2972,7 +3526,11 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType,DataType}[(LLVM.Value[idx], LLVMType, jlType)] + todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType,DataType}[( + LLVM.Value[idx], + LLVMType, + jlType, + )] while length(todo) != 0 path, ty, jlty = popfirst!(todo) @@ -2981,7 +3539,11 @@ function 
zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) loc = gep!(builder, LLVMType, nobj, path) mod = LLVM.parent(LLVM.parent(Base.position(builder))) fill_val = unsafe_nothing_to_llvm(mod) - loc = bitcast!(builder, loc, LLVM.PointerType(T_prjlvalue, addrspace(value_type(loc)))) + loc = bitcast!( + builder, + loc, + LLVM.PointerType(T_prjlvalue, addrspace(value_type(loc))), + ) store!(builder, fill_val, loc) elseif zeroAll loc = gep!(builder, LLVMType, nobj, path) @@ -2996,36 +3558,36 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) end continue end - if isa(ty, LLVM.ArrayType) - for i=1:length(ty) + if isa(ty, LLVM.ArrayType) + for i = 1:length(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty), eltype(jlty))) end continue end - if isa(ty, LLVM.VectorType) - for i=1:size(ty) + if isa(ty, LLVM.VectorType) + for i = 1:size(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty), eltype(jlty))) end continue end if isa(ty, LLVM.StructType) i = 1 - for ii in 1:fieldcount(jlty) + for ii = 1:fieldcount(jlty) jlet = fieldtype(jlty, ii) if isghostty(jlet) || Core.Compiler.isconstType(jlet) continue end t = LLVM.elements(ty)[i] npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, t, jlet)) - i+=1 + i += 1 end - @assert i == Int(length(LLVM.elements(ty)))+1 + @assert i == Int(length(LLVM.elements(ty))) + 1 continue end end @@ -3034,7 +3596,15 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) end -function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize, Size, zeroAll::Bool)::LLVM.API.LLVMValueRef +function zero_allocation( + B::LLVM.IRBuilder, + 
jlType, + LLVMType, + obj, + AlignedSize, + Size, + zeroAll::Bool, +)::LLVM.API.LLVMValueRef func = LLVM.parent(position(B)) mod = LLVM.parent(func) T_int8 = LLVM.Int8Type() @@ -3043,7 +3613,11 @@ function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize, T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - wrapper_f = LLVM.Function(mod, "zeroType", LLVM.FunctionType(LLVM.VoidType(), [value_type(obj), T_int8, value_type(Size)])) + wrapper_f = LLVM.Function( + mod, + "zeroType", + LLVM.FunctionType(LLVM.VoidType(), [value_type(obj), T_int8, value_type(Size)]), + ) push!(function_attributes(wrapper_f), StringAttribute("enzyme_math", "enzyme_zerotype")) push!(function_attributes(wrapper_f), StringAttribute("enzyme_inactive")) push!(function_attributes(wrapper_f), StringAttribute("enzyme_no_escaping_allocation")) @@ -3064,24 +3638,46 @@ function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize, exit = BasicBlock(wrapper_f, "exit") position!(builder, entry) nobj, _, nsize = collect(parameters(wrapper_f)) - nobj = pointercast!(builder, nobj, LLVM.PointerType(LLVMType, addrspace(value_type(nobj)))) + nobj = pointercast!( + builder, + nobj, + LLVM.PointerType(LLVMType, addrspace(value_type(nobj))), + ) LLVM.br!(builder, loop) position!(builder, loop) idx = LLVM.phi!(builder, value_type(Size)) inc = add!(builder, idx, LLVM.ConstantInt(value_type(Size), 1)) - append!(LLVM.incoming(idx), [(LLVM.ConstantInt(value_type(Size), 0), entry), (inc, loop)]) + append!( + LLVM.incoming(idx), + [(LLVM.ConstantInt(value_type(Size), 0), entry), (inc, loop)], + ) zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) - br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, inc, LLVM.Value(LLVM.API.LLVMBuildExactUDiv(builder, nsize, AlignedSize, ""))), exit, loop) + br!( + builder, + icmp!( + builder, + LLVM.API.LLVMIntEQ, + inc, + LLVM.Value(LLVM.API.LLVMBuildExactUDiv(builder, nsize, 
AlignedSize, "")), + ), + exit, + loop, + ) position!(builder, exit) ret!(builder) dispose(builder) end - return call!(B, LLVM.function_type(wrapper_f), wrapper_f, [obj, LLVM.ConstantInt(T_int8, 0), Size]).ref + return call!( + B, + LLVM.function_type(wrapper_f), + wrapper_f, + [obj, LLVM.ConstantInt(T_int8, 0), Size], + ).ref end function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) @@ -3102,7 +3698,8 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) TT = tape_type(LLVMType) if esizeof(TT) != convert(Int, AlignedSize) - GPUCompiler.@safe_error "Enzyme aligned size and Julia size disagree" AlignedSize=convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) + GPUCompiler.@safe_error "Enzyme aligned size and Julia size disagree" AlignedSize = + convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) emit_error(B, nothing, "Enzyme: Tape allocation failed.") # TODO: Pick appropriate orig return LLVM.API.LLVMValueRef(LLVM.UndefValue(LLVMType).ref) end @@ -3110,9 +3707,11 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) if Count isa LLVM.ConstantInt N = convert(Int, Count) - ETT = N == 1 ? TT : NTuple{N, TT} - if sizeof(ETT) != N*convert(Int, AlignedSize) - GPUCompiler.@safe_error "Size of Enzyme tape is incorrect. Please report this issue" ETT sizeof(ETT) TargetSize = N*convert(Int, AlignedSize) LLVMType + ETT = N == 1 ? TT : NTuple{N,TT} + if sizeof(ETT) != N * convert(Int, AlignedSize) + GPUCompiler.@safe_error "Size of Enzyme tape is incorrect. 
Please report this issue" ETT sizeof( + ETT, + ) TargetSize = N * convert(Int, AlignedSize) LLVMType emit_error(B, nothing, "Enzyme: Tape allocation failed.") # TODO: Pick appropriate orig return LLVM.API.LLVMValueRef(LLVM.UndefValue(LLVMType).ref) @@ -3137,7 +3736,8 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) @static if VERSION >= v"1.10.5" needs_dynamic_size_workaround = false else - needs_dynamic_size_workaround = !isa(Size, LLVM.ConstantInt) || convert(Int, Size) != 1 + needs_dynamic_size_workaround = + !isa(Size, LLVM.ConstantInt) || convert(Int, Size) != 1 end T_size_t = convert(LLVM.LLVMType, Int) @@ -3150,12 +3750,16 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) obj = emit_allocobj!(B, tag, allocSize, needs_dynamic_size_workaround) if ZI != C_NULL - unsafe_store!(ZI, zero_allocation(B, TT, LLVMType, obj, AlignedSize, Size, #=ZeroAll=#false)) + unsafe_store!( + ZI, + zero_allocation(B, TT, LLVMType, obj, AlignedSize, Size, false), + ) #=ZeroAll=# end AS = Tracked else ptr8 = LLVM.PointerType(LLVM.IntType(8)) - mallocF, fty = get_function!(mod, "malloc", LLVM.FunctionType(ptr8, [value_type(Count)])) + mallocF, fty = + get_function!(mod, "malloc", LLVM.FunctionType(ptr8, [value_type(Count)])) obj = call!(B, fty, mallocF, [Size]) # if ZI != C_NULL @@ -3166,13 +3770,29 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) AS = 0 end - LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("noalias")) - LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("nonnull")) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("noalias"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("nonnull"), + ) if isa(Count, LLVM.ConstantInt) val = convert(UInt, AlignedSize) val *= convert(UInt, Count) - 
LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("dereferenceable", val)) - LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("dereferenceable_or_null", val)) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("dereferenceable", val), + ) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("dereferenceable_or_null", val), + ) end mem = pointercast!(B, obj, LLVM.PointerType(LLVMType, AS)) @@ -3195,7 +3815,11 @@ function julia_deallocator(B::LLVM.IRBuilder, Obj::LLVM.Value) ptr8 = LLVM.PointerType(LLVM.IntType(8)) freeF, fty = get_function!(mod, "free", LLVM.FunctionType(T_void, [ptr8])) callf = call!(B, fty, freeF, [pointercast!(B, Obj, ptr8)]) - LLVM.API.LLVMAddCallSiteAttribute(callf, LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nonnull")) + LLVM.API.LLVMAddCallSiteAttribute( + callf, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nonnull"), + ) end return LLVM.API.LLVMValueRef(callf.ref) end @@ -3208,10 +3832,14 @@ function emit_inacterror(B, V, orig) mod = LLVM.parent(fn) bt = GPUCompiler.backtrace(orig) - bts = sprint(io->Base.show_backtrace(io, bt)) - fmt = globalstring_ptr!(B, "%s:\nBacktrace\n"*bts) + bts = sprint(io -> Base.show_backtrace(io, bt)) + fmt = globalstring_ptr!(B, "%s:\nBacktrace\n" * bts) - funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[LLVM.PointerType(LLVM.Int8Type())], vararg=true) + funcT = LLVM.FunctionType( + LLVM.VoidType(), + LLVMType[LLVM.PointerType(LLVM.Int8Type())], + vararg = true, + ) func, _ = get_function!(mod, "jl_errorf", funcT, [EnumAttribute("noreturn")]) call!(B, funcT, func, LLVM.Value[fmt, LLVM.Value(V)]) @@ -3224,31 +3852,22 @@ include("rules/llvmrules.jl") for (k, v) in ( ("enz_runtime_newtask_fwd", Enzyme.Compiler.runtime_newtask_fwd), ("enz_runtime_newtask_augfwd", Enzyme.Compiler.runtime_newtask_augfwd), - ("enz_runtime_generic_fwd", 
Enzyme.Compiler.runtime_generic_fwd), ("enz_runtime_generic_augfwd", Enzyme.Compiler.runtime_generic_augfwd), ("enz_runtime_generic_rev", Enzyme.Compiler.runtime_generic_rev), - ("enz_runtime_iterate_fwd", Enzyme.Compiler.runtime_iterate_fwd), ("enz_runtime_iterate_augfwd", Enzyme.Compiler.runtime_iterate_augfwd), ("enz_runtime_iterate_rev", Enzyme.Compiler.runtime_iterate_rev), - ("enz_runtime_newstruct_augfwd", Enzyme.Compiler.runtime_newstruct_augfwd), ("enz_runtime_newstruct_rev", Enzyme.Compiler.runtime_newstruct_rev), - ("enz_runtime_tuple_augfwd", Enzyme.Compiler.runtime_tuple_augfwd), ("enz_runtime_tuple_rev", Enzyme.Compiler.runtime_tuple_rev), - - ("enz_runtime_jl_getfield_aug", Enzyme.Compiler.rt_jl_getfield_aug), ("enz_runtime_jl_getfield_rev", Enzyme.Compiler.rt_jl_getfield_rev), - ("enz_runtime_idx_jl_getfield_aug", Enzyme.Compiler.idx_jl_getfield_aug), ("enz_runtime_idx_jl_getfield_rev", Enzyme.Compiler.idx_jl_getfield_aug), - ("enz_runtime_jl_setfield_aug", Enzyme.Compiler.rt_jl_setfield_aug), ("enz_runtime_jl_setfield_rev", Enzyme.Compiler.rt_jl_setfield_rev), - ("enz_runtime_error_if_differentiable", Enzyme.Compiler.error_if_differentiable), ("enz_runtime_error_if_active", Enzyme.Compiler.error_if_active), ) @@ -3258,30 +3877,103 @@ end function __init__() API.memmove_warning!(false) API.typeWarning!(false) - API.EnzymeSetHandler(@cfunction(julia_error, LLVM.API.LLVMValueRef, (Cstring, LLVM.API.LLVMValueRef, API.ErrorType, Ptr{Cvoid}, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef))) - API.EnzymeSetSanitizeDerivatives(@cfunction(julia_sanitize, LLVM.API.LLVMValueRef, (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))); - API.EnzymeSetRuntimeInactiveError(@cfunction(emit_inacterror, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))) - API.EnzymeSetDefaultTapeType(@cfunction( - julia_default_tape_type, LLVM.API.LLVMTypeRef, (LLVM.API.LLVMContextRef,))) - 
API.EnzymeSetCustomAllocator(@cfunction( - julia_allocator, LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, UInt8, Ptr{LLVM.API.LLVMValueRef}))) - API.EnzymeSetCustomDeallocator(@cfunction( - julia_deallocator, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) - API.EnzymeSetPostCacheStore(@cfunction( - julia_post_cache_store, Ptr{LLVM.API.LLVMValueRef}, - (LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, Ptr{UInt64}))) - - API.EnzymeSetCustomZero(@cfunction( - zero_allocation, Cvoid, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, UInt8))) - API.EnzymeSetFixupReturn(@cfunction( - fixup_return, LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) - API.EnzymeSetUndefinedValueForType(@cfunction( - julia_undef_value_for_type, LLVM.API.LLVMValueRef, (LLVM.API.LLVMModuleRef, LLVM.API.LLVMTypeRef,UInt8))) - API.EnzymeSetShadowAllocRewrite(@cfunction( - shadow_alloc_rewrite, Cvoid, (LLVM.API.LLVMValueRef,API.EnzymeGradientUtilsRef))) + API.EnzymeSetHandler( + @cfunction( + julia_error, + LLVM.API.LLVMValueRef, + ( + Cstring, + LLVM.API.LLVMValueRef, + API.ErrorType, + Ptr{Cvoid}, + LLVM.API.LLVMValueRef, + LLVM.API.LLVMBuilderRef, + ) + ) + ) + API.EnzymeSetSanitizeDerivatives( + @cfunction( + julia_sanitize, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMValueRef, + LLVM.API.LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + ) + ) + ) + API.EnzymeSetRuntimeInactiveError( + @cfunction( + emit_inacterror, + Cvoid, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) + ) + API.EnzymeSetDefaultTapeType( + @cfunction( + julia_default_tape_type, + LLVM.API.LLVMTypeRef, + (LLVM.API.LLVMContextRef,) + ) + ) + API.EnzymeSetCustomAllocator( + @cfunction( + julia_allocator, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMTypeRef, + LLVM.API.LLVMValueRef, + LLVM.API.LLVMValueRef, + 
UInt8, + Ptr{LLVM.API.LLVMValueRef}, + ) + ) + ) + API.EnzymeSetCustomDeallocator( + @cfunction( + julia_deallocator, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef) + ) + ) + API.EnzymeSetPostCacheStore( + @cfunction( + julia_post_cache_store, + Ptr{LLVM.API.LLVMValueRef}, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, Ptr{UInt64}) + ) + ) + + API.EnzymeSetCustomZero( + @cfunction( + zero_allocation, + Cvoid, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, UInt8) + ) + ) + API.EnzymeSetFixupReturn( + @cfunction( + fixup_return, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef) + ) + ) + API.EnzymeSetUndefinedValueForType( + @cfunction( + julia_undef_value_for_type, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMModuleRef, LLVM.API.LLVMTypeRef, UInt8) + ) + ) + API.EnzymeSetShadowAllocRewrite( + @cfunction( + shadow_alloc_rewrite, + Cvoid, + (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef) + ) + ) register_alloc_rules() register_llvm_rules() @@ -3291,8 +3983,7 @@ function __init__() end # Define EnzymeTarget -Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget -end +Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget end GPUCompiler.llvm_triple(::EnzymeTarget) = LLVM.triple(JIT.get_jit()) GPUCompiler.llvm_datalayout(::EnzymeTarget) = LLVM.datalayout(JIT.get_jit()) @@ -3301,20 +3992,19 @@ function GPUCompiler.llvm_machine(::EnzymeTarget) return JIT.get_tm() end -module Runtime -end +module Runtime end abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams TT::Type{<:Tuple} mode::API.CDerivativeMode width::Int - rt::Type{<:Annotation{T} where T} + rt::Type{<:Annotation{T} where {T}} run_enzyme::Bool abiwrap::Bool # Whether, in split mode, acessible primal argument data is modified # between the call and the split - modifiedBetween::NTuple{N, Bool} where N + modifiedBetween::NTuple{N,Bool} where {N} # 
Whether to also return the primal returnPrimal::Bool # Whether to (in aug fwd) += by one @@ -3335,7 +4025,8 @@ struct PrimalCompilerParams <: AbstractEnzymeCompilerParams mode::API.CDerivativeMode end -DefaultCompilerTarget(;kwargs...) = GPUCompiler.NativeCompilerTarget(;jlruntime=true, kwargs...) +DefaultCompilerTarget(; kwargs...) = + GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...) ## job @@ -3352,41 +4043,55 @@ GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme" # provide a specific interpreter to use. if VERSION >= v"1.11.0-DEV.1552" -struct EnzymeCacheToken - target_type::Type - always_inline - method_table::Core.MethodTable - param_type::Type - is_fwd::Bool -end - -GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - EnzymeCacheToken( - typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job), - typeof(job.config.params), job.config.params.mode == API.DEM_ForwardMode, - ) + struct EnzymeCacheToken + target_type::Type + always_inline::Any + method_table::Core.MethodTable + param_type::Type + is_fwd::Bool + end -GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) + GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + EnzymeCacheToken( + typeof(job.config.target), + job.config.always_inline, + GPUCompiler.method_table(job), + typeof(job.config.params), + job.config.params.mode == API.DEM_ForwardMode, + ) + + GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + Interpreter.EnzymeInterpreter( + GPUCompiler.ci_cache_token(job), + GPUCompiler.method_table(job), + job.world, + job.config.params.mode, + ) else -# the codeinstance cache to use -- should only be used for the constructor -# Note that the only way the interpreter modifies codegen is 
either not inlining a fwd mode -# rule or not inlining a rev mode rule. Otherwise, all caches can be re-used. -const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache() -const GLOBAL_REV_CACHE = GPUCompiler.CodeCache() -function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) - return if job.config.params.mode == API.DEM_ForwardMode - GLOBAL_FWD_CACHE - else - GLOBAL_REV_CACHE + # the codeinstance cache to use -- should only be used for the constructor + # Note that the only way the interpreter modifies codegen is either not inlining a fwd mode + # rule or not inlining a rev mode rule. Otherwise, all caches can be re-used. + const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache() + const GLOBAL_REV_CACHE = GPUCompiler.CodeCache() + function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) + return if job.config.params.mode == API.DEM_ForwardMode + GLOBAL_FWD_CACHE + else + GLOBAL_REV_CACHE + end end -end -GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = enzyme_ci_cache(job) + GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + enzyme_ci_cache(job) -GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpreter(enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) + GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + Interpreter.EnzymeInterpreter( + enzyme_ci_cache(job), + GPUCompiler.method_table(job), + job.world, + job.config.params.mode, + ) end include("compiler/passes.jl") @@ -3399,55 +4104,75 @@ import .Interpreter: isKWCallSignature """ Create the methodinstance pair, and lookup the primal return type. """ -@inline function fspec(@nospecialize(F), @nospecialize(TT), world::Union{Integer, Nothing}=nothing) +@inline function fspec( + @nospecialize(F), + @nospecialize(TT), + world::Union{Integer,Nothing} = nothing, +) # primal function. 
Inferred here to get return type _tt = (TT.parameters...,) primal_tt = Tuple{map(eltype, _tt)...} primal = if world isa Nothing - GPUCompiler.methodinstance(F, primal_tt) + GPUCompiler.methodinstance(F, primal_tt) else - GPUCompiler.methodinstance(F, primal_tt, world) + GPUCompiler.methodinstance(F, primal_tt, world) end return primal end -@generated function primal_return_type(::ReverseMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} +@generated function primal_return_type( + ::ReverseMode, + ::Val{world}, + ::Type{FT}, + ::Type{TT}, +) where {world,FT,TT} mode = Enzyme.API.DEM_ReverseModeCombined CT = @static if VERSION >= v"1.11.0-DEV.1552" EnzymeCacheToken( - typeof(DefaultCompilerTarget()), #=job.config.always_inline=#false, GPUCompiler.GLOBAL_METHOD_TABLE, - EnzymeCompilerParams, false, + typeof(DefaultCompilerTarget()), + false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# + EnzymeCompilerParams, + false, ) else Enzyme.Compiler.GLOBAL_REV_CACHE end interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) - res = Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) + res = Core.Compiler._return_type(interp, Tuple{FT,TT.parameters...}) return quote Base.@_inline_meta $res end end -@generated function primal_return_type(::ForwardMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} +@generated function primal_return_type( + ::ForwardMode, + ::Val{world}, + ::Type{FT}, + ::Type{TT}, +) where {world,FT,TT} mode = Enzyme.API.DEM_ForwardMode CT = @static if VERSION >= v"1.11.0-DEV.1552" EnzymeCacheToken( - typeof(DefaultCompilerTarget()), #=always_inline=#false, GPUCompiler.GLOBAL_METHOD_TABLE, - EnzymeCompilerParams, false, + typeof(DefaultCompilerTarget()), + false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=always_inline=# + EnzymeCompilerParams, + false, ) else Enzyme.Compiler.GLOBAL_FWD_CACHE end interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) - res = 
Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) + res = Core.Compiler._return_type(interp, Tuple{FT,TT.parameters...}) return quote Base.@_inline_meta $res @@ -3467,7 +4192,7 @@ function annotate!(mod, mode) for f in fns API.EnzymeAttributeKnownFunctions(f.ref) end - + for gname in inactiveglobs globs = LLVM.globals(mod) if haskey(globs, gname) @@ -3496,8 +4221,22 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -3521,7 +4260,14 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("nofree", 0)) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + LLVM.EnumAttribute("nofree", 0), + ) end end end @@ -3533,7 +4279,8 @@ function annotate!(mod, mode) end end - for fname in ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") + for fname in + ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") if haskey(fns, fname) fn = fns[fname] if LLVM.version().major <= 15 @@ -3551,14 +4298,25 @@ function annotate!(mod, mode) end end - for fname in ("jl_excstack_state","ijl_excstack_state", "ijl_field_index", "jl_field_index") + for fname in + ("jl_excstack_state", 
"ijl_excstack_state", "ijl_field_index", "jl_field_index") if haskey(fns, fname) fn = fns[fname] if LLVM.version().major <= 15 push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3570,7 +4328,14 @@ function annotate!(mod, mode) end end - for fname in ("jl_f_getfield","ijl_f_getfield","jl_get_nth_field_checked","ijl_get_nth_field_checked", "jl_f__svec_ref", "ijl_f__svec_ref") + for fname in ( + "jl_f_getfield", + "ijl_f_getfield", + "jl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_f__svec_ref", + "ijl_f__svec_ref", + ) if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) @@ -3592,9 +4357,23 @@ function annotate!(mod, mode) attr = if LLVM.version().major <= 15 LLVM.EnumAttribute("readonly") else - EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + attr, + ) end end end @@ -3619,30 +4398,52 
@@ function annotate!(mod, mode) end end - for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw", "julia.pointer_from_objref", - "ijl_array_grow_end", "jl_array_grow_end", "ijl_array_del_end", "jl_array_del_end", - "ijl_array_grow_beg", "jl_array_grow_beg", "ijl_array_del_beg", "jl_array_del_beg", - "ijl_array_grow_at", "jl_array_grow_at", - "ijl_array_del_at", "jl_array_del_at", - "ijl_pop_handler", "jl_pop_handler", - "ijl_push_handler", "jl_push_handler", - "ijl_module_name", "jl_module_name", - "ijl_restore_excstack", "jl_restore_excstack", - "julia.except_enter", - "ijl_get_nth_field_checked", "jl_get_nth_field_checked", - "jl_egal__unboxed", - "ijl_reshape_array", "jl_reshape_array", - "ijl_eqtable_get", "jl_eqtable_get", - "jl_gc_run_pending_finalizers", - "ijl_try_substrtod", "jl_try_substrtod", - ) + for fname in ( + "julia.get_pgcstack", + "julia.ptls_states", + "jl_get_ptls_states", + "julia.safepoint", + "ijl_throw", + "julia.pointer_from_objref", + "ijl_array_grow_end", + "jl_array_grow_end", + "ijl_array_del_end", + "jl_array_del_end", + "ijl_array_grow_beg", + "jl_array_grow_beg", + "ijl_array_del_beg", + "jl_array_del_beg", + "ijl_array_grow_at", + "jl_array_grow_at", + "ijl_array_del_at", + "jl_array_del_at", + "ijl_pop_handler", + "jl_pop_handler", + "ijl_push_handler", + "jl_push_handler", + "ijl_module_name", + "jl_module_name", + "ijl_restore_excstack", + "jl_restore_excstack", + "julia.except_enter", + "ijl_get_nth_field_checked", + "jl_get_nth_field_checked", + "jl_egal__unboxed", + "ijl_reshape_array", + "jl_reshape_array", + "ijl_eqtable_get", + "jl_eqtable_get", + "jl_gc_run_pending_finalizers", + "ijl_try_substrtod", + "jl_try_substrtod", + ) if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), no_escaping_alloc) end end - + for fname in ("julia.pointer_from_objref",) if haskey(fns, fname) @@ -3655,14 +4456,35 @@ function annotate!(mod, mode) end end - for boxfn in 
("julia.gc_alloc_obj", "jl_gc_alloc_typed", "ijl_gc_alloc_typed", - "jl_box_float32", "jl_box_float64", "jl_box_int32", "jl_box_int64", - "ijl_box_float32", "ijl_box_float64", "ijl_box_int32", "ijl_box_int64", - "jl_alloc_array_1d", "jl_alloc_array_2d", "jl_alloc_array_3d", - "ijl_alloc_array_1d", "ijl_alloc_array_2d", "ijl_alloc_array_3d", - "jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash", - "jl_f_tuple", "ijl_f_tuple", "jl_new_structv", "ijl_new_structv", - "ijl_new_array", "jl_new_array") + for boxfn in ( + "julia.gc_alloc_obj", + "jl_gc_alloc_typed", + "ijl_gc_alloc_typed", + "jl_box_float32", + "jl_box_float64", + "jl_box_int32", + "jl_box_int64", + "ijl_box_float32", + "ijl_box_float64", + "ijl_box_int32", + "ijl_box_int64", + "jl_alloc_array_1d", + "jl_alloc_array_2d", + "jl_alloc_array_3d", + "ijl_alloc_array_1d", + "ijl_alloc_array_2d", + "ijl_alloc_array_3d", + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + "jl_f_tuple", + "ijl_f_tuple", + "jl_new_structv", + "ijl_new_structv", + "ijl_new_array", + "jl_new_array", + ) if haskey(fns, boxfn) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) @@ -3670,9 +4492,23 @@ function annotate!(mod, mode) accattr = if LLVM.version().major <= 15 LLVM.EnumAttribute("inaccessiblememonly") else - EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) end - if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) + if !( + boxfn in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) push!(function_attributes(fn), accattr) end for u in 
LLVM.uses(fn) @@ -3682,9 +4518,27 @@ function annotate!(mod, mode) end cf = LLVM.called_operand(c) if cf == fn - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) - if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), accattr) + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) + if !( + boxfn in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + accattr, + ) end end if !isa(cf, LLVM.Function) @@ -3696,15 +4550,47 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) - if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) + if !( + boxfn in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) attr = if LLVM.version().major <= 15 LLVM.EnumAttribute("inaccessiblememonly") else - EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + EnumAttribute( + "memory", + MemoryEffect( + 
(MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + attr, + ) end end end @@ -3716,7 +4602,17 @@ function annotate!(mod, mode) if LLVM.version().major <= 15 push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3730,7 +4626,17 @@ function annotate!(mod, mode) push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3745,7 +4651,17 @@ function annotate!(mod, mode) if LLVM.version().major <= 15 push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << 
getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_ModRef << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3757,7 +4673,17 @@ function annotate!(mod, mode) push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3774,7 +4700,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end -function enzyme_custom_extract_mi(orig::LLVM.Instruction, error=true) +function enzyme_custom_extract_mi(orig::LLVM.Instruction, error = true) operand = LLVM.called_operand(orig) if isa(operand, LLVM.Function) return enzyme_custom_extract_mi(operand::LLVM.Function, error) @@ -3784,7 +4710,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Instruction, error=true) return nothing, nothing end -function enzyme_custom_extract_mi(orig::LLVM.Function, error=true) +function enzyme_custom_extract_mi(orig::LLVM.Function, error = true) mi = nothing RT = nothing for fattr in collect(function_attributes(orig)) @@ -3805,7 +4731,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Function, error=true) return mi, RT end -function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error=true) +function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error = 
true) ty = nothing byref = nothing for fattr in collect(parameter_attributes(fn, idx)) @@ -3820,7 +4746,9 @@ function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error=true) end end if error && (byref === nothing || ty === nothing) - GPUCompiler.@safe_error "Enzyme: Custom handler, could not find parm type at index", idx, fn + GPUCompiler.@safe_error "Enzyme: Custom handler, could not find parm type at index", + idx, + fn end return ty, byref end @@ -3828,18 +4756,39 @@ end include("rules/typerules.jl") include("rules/activityrules.jl") -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: Const = API.DFT_CONSTANT -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: Active = API.DFT_OUT_DIFF -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: Duplicated = API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicated = API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedFunc = API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: DuplicatedNoNeed = API.DFT_DUP_NONEED -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Const} = API.DFT_CONSTANT +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Active} = + API.DFT_OUT_DIFF +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Duplicated} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicated} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicatedFunc} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:DuplicatedNoNeed} = + API.DFT_DUP_NONEED +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicatedNoNeed} = + API.DFT_DUP_NONEED 
const DumpPreEnzyme = Ref(false) const DumpPostWrap = Ref(false) -function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) +function enzyme!( + job, + mod, + primalf, + TT, + mode, + width, + parallel, + actualRetType, + wrap, + modifiedBetween, + returnPrimal, + expectedTapeType, + loweredArgs, + boxedArgs, +) if DumpPreEnzyme[] API.EnzymeDumpModuleRef(mod.ref) end @@ -3853,17 +4802,24 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr ctx = context(mod) dl = string(LLVM.datalayout(mod)) - tt = [TT.parameters[2:end]...,] + tt = [TT.parameters[2:end]...] - args_activity = API.CDIFFE_TYPE[] - uncacheable_args = Bool[] - args_typeInfo = TypeTree[] + args_activity = API.CDIFFE_TYPE[] + uncacheable_args = Bool[] + args_typeInfo = TypeTree[] args_known_values = API.IntList[] @assert length(modifiedBetween) == length(TT.parameters) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(primalf, i)))) for i in 1:length(collect(parameters(primalf)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(primalf, i)), + ), + ) for i = 1:length(collect(parameters(primalf))) + ) if swiftself push!(args_activity, API.DFT_CONSTANT) push!(args_typeInfo, TypeTree()) @@ -3876,7 +4832,11 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr source_typ = eltype(T) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) if !(T <: Const) - error("Type of ghost or constant type "*string(T)*" is marked as differentiable.") + error( + "Type of ghost or constant type " * + string(T) * + " is marked as differentiable.", + ) end continue end @@ -3890,9 +4850,13 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else push!(args_activity, API.DFT_OUT_DIFF) end - elseif T <: Duplicated 
|| T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated + elseif T <: Duplicated || + T <: BatchDuplicated || + T <: BatchDuplicatedFunc || + T <: MixedDuplicated || + T <: BatchMixedDuplicated push!(args_activity, API.DFT_DUP_ARG) - elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed + elseif T <: DuplicatedNoNeed || T <: BatchDuplicatedNoNeed push!(args_activity, API.DFT_DUP_NONEED) else error("illegal annotation type $T") @@ -3922,36 +4886,105 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr convert(API.CDIFFE_TYPE, rt) end - rules = Dict{String, API.CustomRuleType}( - "jl_array_copy" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_array_copy" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "julia.pointer_from_objref" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_inactive_inout" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_excstack_state" => @cfunction(int_return_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_excstack_state" => @cfunction(int_return_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "julia.except_enter" => @cfunction(int_return_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), + rules = Dict{String,API.CustomRuleType}( + "jl_array_copy" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + 
LLVM.API.LLVMValueRef, + ) + ), + "ijl_array_copy" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "julia.pointer_from_objref" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "jl_inactive_inout" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "jl_excstack_state" => @cfunction( + int_return_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "ijl_excstack_state" => @cfunction( + int_return_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "julia.except_enter" => @cfunction( + int_return_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), ) logic = Logic() TA = TypeAnalysis(logic, rules) - retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? - Ptr{actualRetType} : actualRetType - retTT = (!isa(actualRetType, Union) && actualRetType <: Tuple && in(Any, actualRetType.parameters)) ? TypeTree() : typetree(retT, ctx, dl, seen) + retT = + (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? + Ptr{actualRetType} : actualRetType + retTT = + ( + !isa(actualRetType, Union) && + actualRetType <: Tuple && + in(Any, actualRetType.parameters) + ) ? 
TypeTree() : typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -3959,15 +4992,33 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated) + shadowReturnUsed = + returnUsed && ( + retType == API.DFT_DUP_ARG || + retType == API.DFT_DUP_NONEED || + rt <: MixedDuplicated || + rt <: BatchMixedDuplicated + ) returnUsed &= returnPrimal augmented = API.EnzymeCreateAugmentedPrimal( - logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed, - #=shadowReturnUsed=#shadowReturnUsed, - typeInfo, uncacheable_args, #=forceAnonymousTape=# false, runtimeActivity, width, #=atomicAdd=# parallel) + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, #=returnUsed=# + shadowReturnUsed, #=shadowReturnUsed=# + typeInfo, + uncacheable_args, + false, + runtimeActivity, + width, + parallel, + ) #=atomicAdd=# # 2. 
get new_primalf and tape - augmented_primalf = LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) + augmented_primalf = + LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) if utape != C_NULL @@ -3983,55 +5034,145 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr end if wrap - augmented_primalf = create_abi_wrapper(augmented_primalf, TT, rt, actualRetType, API.DEM_ReverseModePrimal, augmented, width, returnPrimal, shadow_init, world, interp) + augmented_primalf = create_abi_wrapper( + augmented_primalf, + TT, + rt, + actualRetType, + API.DEM_ReverseModePrimal, + augmented, + width, + returnPrimal, + shadow_init, + world, + interp, + ) end # TODOs: # 1. Handle mutable or !pointerfree arguments by introducing caching # + specifically by setting uncacheable_args[i] = true - adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( - logic, primalf, retType, args_activity, TA, - #=returnValue=#false, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeGradient, runtimeActivity, width, - #=additionalArg=#tape, #=forceAnonymousTape=#false, typeInfo, - uncacheable_args, augmented, #=atomicAdd=# parallel)) + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + false, + false, + API.DEM_ReverseModeGradient, + runtimeActivity, + width, #=mode=# + tape, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + augmented, + parallel, + ), + ) #=atomicAdd=# if wrap - adjointf = create_abi_wrapper(adjointf, TT, rt, actualRetType, API.DEM_ReverseModeGradient, augmented, width, #=returnPrimal=#false, shadow_init, world, interp) + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeGradient, + augmented, + width, + false, + shadow_init, + world, + interp, + ) #=returnPrimal=# 
end elseif mode == API.DEM_ReverseModeCombined returnUsed = !isghostty(actualRetType) returnUsed &= returnPrimal - adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( - logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeCombined, runtimeActivity, width, - #=additionalArg=#C_NULL, #=forceAnonymousTape=#false, typeInfo, - uncacheable_args, #=augmented=#C_NULL, #=atomicAdd=# parallel)) + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + false, + API.DEM_ReverseModeCombined, + runtimeActivity, + width, #=mode=# + C_NULL, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + C_NULL, + parallel, + ), + ) #=atomicAdd=# augmented_primalf = nothing if wrap - adjointf = create_abi_wrapper(adjointf, TT, rt, actualRetType, API.DEM_ReverseModeCombined, nothing, width, returnPrimal, shadow_init, world, interp) + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeCombined, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + ) end elseif mode == API.DEM_ForwardMode returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) returnUsed &= returnPrimal - adjointf = LLVM.Function(API.EnzymeCreateForwardDiff( - logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=mode=#API.DEM_ForwardMode, runtimeActivity, width, - #=additionalArg=#C_NULL, typeInfo, - uncacheable_args)) + adjointf = LLVM.Function( + API.EnzymeCreateForwardDiff( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + API.DEM_ForwardMode, + runtimeActivity, + width, #=mode=# + C_NULL, + typeInfo, #=additionalArg=# + uncacheable_args, + ), + ) augmented_primalf = nothing if wrap - pf = adjointf - adjointf = create_abi_wrapper(adjointf, TT, rt, actualRetType, API.DEM_ForwardMode, nothing, width, returnPrimal, shadow_init, world, 
interp) + pf = adjointf + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ForwardMode, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + ) end else @assert "Unhandled derivative mode", mode end if DumpPostWrap[] API.EnzymeDumpModuleRef(mod.ref) - end + end API.EnzymeLogicErasePreprocessedFunctions(logic) adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) - augmented_primalfname = augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) + augmented_primalfname = + augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) for f in collect(functions(mod)) API.EnzymeFixupBatchedJuliaCallingConvention(f) end @@ -4041,7 +5182,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr end fix_decayaddr!(mod) adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] - augmented_primalf = augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname] + augmented_primalf = + augmented_primalf == nothing ? 
nothing : functions(mod)[augmented_primalfname] return adjointf, augmented_primalf, TapeType end @@ -4061,18 +5203,39 @@ function set_subprogram!(f::LLVM.Function, sp) end end -function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, Mode::API.CDerivativeMode, augmented, width, returnPrimal, shadow_init, world, interp) +function create_abi_wrapper( + enzymefn::LLVM.Function, + TT, + rettype, + actualRetType, + Mode::API.CDerivativeMode, + augmented, + width, + returnPrimal, + shadow_init, + world, + interp, +) is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined - is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal + is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal needs_tape = Mode == API.DEM_ReverseModeGradient mod = LLVM.parent(enzymefn) ctx = LLVM.context(mod) push!(function_attributes(enzymefn), EnumAttribute("alwaysinline", 0)) - hasNoInline = any(map(k->kind(k)==kind(EnumAttribute("noinline")), collect(function_attributes(enzymefn)))) + hasNoInline = any( + map( + k -> kind(k) == kind(EnumAttribute("noinline")), + collect(function_attributes(enzymefn)), + ), + ) if hasNoInline - LLVM.API.LLVMRemoveEnumAttributeAtIndex(enzymefn, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("noinline"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + enzymefn, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("noinline")), + ) end T_void = convert(LLVMType, Nothing) ptr8 = LLVM.PointerType(LLVM.IntType(8)) @@ -4082,7 +5245,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, # Create Enzyme calling convention T_wrapperargs = LLVMType[] # Arguments of the wrapper - sret_types = Type[] # Julia types of all returned variables + sret_types = Type[] # Julia types of all returned variables pactualRetType = actualRetType 
sret_union = is_sret_union(actualRetType) @@ -4122,7 +5285,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if width == 1 push!(ActiveRetTypes, source_typ) else - push!(ActiveRetTypes, NTuple{width, source_typ}) + push!(ActiveRetTypes, NTuple{width,source_typ}) end end elseif T <: Duplicated || T <: DuplicatedNoNeed @@ -4148,7 +5311,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if is_adjoint NT = Tuple{ActiveRetTypes...} - if any(any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for b in ActiveRetTypes) + if any( + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for + b in ActiveRetTypes + ) NT = AnonymousStruct(NT) end push!(sret_types, NT) @@ -4156,26 +5322,40 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, # API.DFT_OUT_DIFF if is_adjoint - if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated @assert !sret_union if allocatedinline(actualRetType) != allocatedinline(literal_rt) msg = sprint() do io println(io, string(enzymefn)) - println(io, "Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype), sret_union=$(sret_union), pactualRetType=$(pactualRetType)") + println( + io, + "Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype), sret_union=$(sret_union), pactualRetType=$(pactualRetType)", + ) end throw(AssertionError(msg)) end - if rettype <: Active + if rettype <: Active if !allocatedinline(actualRetType) - throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) + throw( + AssertionError( + "Base.allocatedinline(actualRetType) returns 
false: actualRetType = $(actualRetType), rettype = $(rettype)", + ), + ) end end - dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active)))) + dretTy = LLVM.LLVMType( + API.EnzymeGetShadowType( + width, + convert(LLVMType, actualRetType; allow_boxed = !(rettype <: Active)), + ), + ) push!(T_wrapperargs, dretTy) end end - data = Array{Int64}(undef, 3) + data = Array{Int64}(undef, 3) existed = Array{UInt8}(undef, 3) if Mode == API.DEM_ReverseModePrimal API.EnzymeExtractReturnInfo(augmented, data, existed) @@ -4208,17 +5388,24 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end # shadow return if existed[3] != 0 - if rettype <: Duplicated || rettype <: DuplicatedNoNeed || rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed || rettype <: BatchDuplicatedFunc + if rettype <: Duplicated || + rettype <: DuplicatedNoNeed || + rettype <: BatchDuplicated || + rettype <: BatchDuplicatedNoNeed || + rettype <: BatchDuplicatedFunc if width == 1 push!(sret_types, literal_rt) else - push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) + push!(sret_types, AnonymousStruct(NTuple{width,literal_rt})) end elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated if width == 1 push!(sret_types, Base.RefValue{literal_rt}) else - push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}})) + push!( + sret_types, + AnonymousStruct(NTuple{width,Base.RefValue{literal_rt}}), + ) end end else @@ -4236,7 +5423,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if width == 1 push!(sret_types, literal_rt) else - push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) + push!(sret_types, AnonymousStruct(NTuple{width,literal_rt})) end end if returnPrimal @@ -4244,11 +5431,14 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - combinedReturn = if 
any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) - AnonymousStruct(Tuple{sret_types...}) - else - Tuple{sret_types...} - end + combinedReturn = + if any( + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types + ) + AnonymousStruct(Tuple{sret_types...}) + else + Tuple{sret_types...} + end uses_sret = is_sret(combinedReturn) @@ -4268,14 +5458,14 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, returnRoots = false root_ty = nothing if uses_sret - returnRoots = deserves_rooting(jltype) - if returnRoots - tracked = CountTrackedPointers(jltype) + returnRoots = deserves_rooting(jltype) + if returnRoots + tracked = CountTrackedPointers(jltype) root_ty = LLVM.ArrayType(T_prjlvalue, tracked.count) pushfirst!(T_wrapperargs, LLVM.PointerType(root_ty)) pushfirst!(T_wrapperargs, LLVM.PointerType(jltype)) - end + end end if needs_tape @@ -4286,7 +5476,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end if tape != C_NULL tape = LLVM.LLVMType(tape) - jltape = convert(LLVM.LLVMType, tape_type(tape); allow_boxed=true) + jltape = convert(LLVM.LLVMType, tape_type(tape); allow_boxed = true) push!(T_wrapperargs, jltape) else needs_tape = false @@ -4295,7 +5485,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, T_ret = returnRoots ? 
T_void : jltype FT = LLVM.FunctionType(T_ret, T_wrapperargs) - llvm_f = LLVM.Function(mod, safe_name(LLVM.name(enzymefn)*"wrap"), FT) + llvm_f = LLVM.Function(mod, safe_name(LLVM.name(enzymefn) * "wrap"), FT) API.EnzymeCloneFunctionDISubprogramInto(llvm_f, enzymefn) dl = datalayout(mod) @@ -4316,7 +5506,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if returnRoots sret = params[i] - i+= 1 + i += 1 attr = if LLVM.version().major >= 12 TypeAttribute("sret", jltype) @@ -4332,7 +5522,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, rootRet = nothing if returnRoots rootRet = params[i] - i+=1 + i += 1 end activeNum = 0 @@ -4348,7 +5538,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, llty = value_type(params[i]) - convty = convert(LLVMType, T′; allow_boxed=true) + convty = convert(LLVMType, T′; allow_boxed = true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter") @@ -4368,7 +5558,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if isboxed if is_split msg = sprint() do io - println(io, "Unimplemented: Had active input arg needing a box in split mode") + println( + io, + "Unimplemented: Had active input arg needing a box in split mode", + ) println(io, T, " at index ", i) println(io, TT) end @@ -4376,13 +5569,28 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end @assert !is_split # TODO replace with better enzyme_zero - ptr = gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)]) + ptr = gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), activeNum), + ], + ) cst = pointercast!(builder, ptr, ptr8) push!(realparms, 
ptr) - LLVM.memset!(builder, cst, LLVM.ConstantInt(LLVM.IntType(8), 0), - LLVM.ConstantInt(LLVM.IntType(64), LLVM.storage_size(dl, Base.eltype(LLVM.value_type(ptr)) )), - #=align=#0 ) + LLVM.memset!( + builder, + cst, + LLVM.ConstantInt(LLVM.IntType(8), 0), + LLVM.ConstantInt( + LLVM.IntType(64), + LLVM.storage_size(dl, Base.eltype(LLVM.value_type(ptr))), + ), + 0, + ) #=align=# end activeNum += 1 elseif T <: Duplicated || T <: DuplicatedNoNeed @@ -4392,9 +5600,13 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, parmsi = params[i] if T <: BatchMixedDuplicated - if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}}) + if GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{T′}}) njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue) - parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi)))) + parmsi = bitcast!( + builder, + parmsi, + LLVM.PointerType(njlvalue, addrspace(value_type(parmsi))), + ) parmsi = load!(builder, njlvalue, parmsi) end end @@ -4404,23 +5616,24 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, resty = isboxed ? llty : LLVM.PointerType(llty, Derived) ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty))) - for idx in 1:width - pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx-1) - pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) + for idx = 1:width + pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx - 1) + pv = + bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived)) if isboxed pv = load!(builder, llty, pv, "mixedboxload") end - ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1) + ival = (width == 1) ? 
pv : insert_value!(builder, ival, pv, idx - 1) end push!(realparms, ival) i += 1 elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed - isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′}) + isboxed = GPUCompiler.deserves_argbox(NTuple{width,T′}) val = params[i] if isboxed - val = load!(builder, val) + val = load!(builder, val) end i += 1 push!(realparms, val) @@ -4430,7 +5643,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, llvmf = nested_codegen!(Mode, mod, funcspec, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) Func_RT = Core.Compiler.typeinf_ext_toplevel(interp, funcspec).rettype - @assert Func_RT == NTuple{width, T′} + @assert Func_RT == NTuple{width,T′} _, psret, _ = get_return_info(Func_RT) args = LLVM.Value[] if psret !== nothing @@ -4439,7 +5652,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args) if get_subprogram(llvmf) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) + metadata(res)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) end if psret !== nothing res = load!(builder, convert(LLVMType, Func_RT), psret) @@ -4450,7 +5663,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) + if is_adjoint && + (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) push!(realparms, params[i]) i += 1 end @@ -4466,7 +5680,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) if get_subprogram(llvm_f) !== nothing - metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) + metadata(val)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) end @inline function 
fixup_abi(index, value) @@ -4486,11 +5700,13 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, # if in split mode and the return is a union marked duplicated, upgrade floating point like shadow returns into ref{ty} since otherwise use of the value will create problems. # 3 is index of shadow - if existed[3] != 0 && sret_union && active_reg_inner(pactualRetType, (), world, #=justActive=#Val(true), #=UnionSret=#Val(true)) == ActiveState + if existed[3] != 0 && + sret_union && + active_reg_inner(pactualRetType, (), world, Val(true), Val(true)) == ActiveState #=UnionSret=# rewrite_union_returns_as_ref(enzymefn, data[3], world, width) end returnNum = 0 - for i in 1:3 + for i = 1:3 if existed[i] != 0 eval = val if data[i] != -1 @@ -4498,31 +5714,56 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end if i == 3 if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated - ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) - for idx in 1:width - pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1) - al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}, "batchmixedret") + ival = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)), + ) + for idx = 1:width + pv = + (width == 1) ? eval : extract_value!(builder, eval, idx - 1) + al0 = + al = emit_allocobj!( + builder, + Base.RefValue{eltype(rettype)}, + "batchmixedret", + ) llty = value_type(pv) - al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + al = bitcast!( + builder, + al, + LLVM.PointerType(llty, addrspace(value_type(al))), + ) store!(builder, pv, al) - emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv)) - ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1) + emit_writebarrier!( + builder, + get_julia_inner_types(builder, al0, pv), + ) + ival = + (width == 1) ? 
al0 : + insert_value!(builder, ival, al0, idx - 1) end eval = ival end end eval = fixup_abi(i, eval) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + ptr = inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), returnNum), + ], + ) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) - returnNum+=1 + returnNum += 1 if i == 3 && shadow_init shadows = LLVM.Value[] if width == 1 push!(shadows, eval) else - for i in 1:width - push!(shadows, extract_value!(builder, eval, i-1)) + for i = 1:width + push!(shadows, extract_value!(builder, eval, i - 1)) end end @@ -4531,7 +5772,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, for shadowv in shadows c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) if get_subprogram(llvm_f) !== nothing - metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) + metadata(c)[LLVM.MD_dbg] = + DILocation(0, 0, get_subprogram(llvm_f)) end end end @@ -4541,14 +5783,24 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if i == 2 ty = actualRetType end - @assert !(isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) ) + @assert !( + isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) + ) @assert Core.Compiler.isconstType(ty) eval = makeInstanceOf(builder, ty) eval = fixup_abi(i, eval) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + ptr = inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), returnNum), + ], + ) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) - returnNum+=1 + returnNum += 1 end end @assert returnNum 
== numLLVMReturns @@ -4571,16 +5823,27 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, count_Sret += 1 end end - for returnNum in 0:(count_Sret-1) - eval = fixup_abi(returnNum+1, if count_llvm_Sret == 0 - makeInstanceOf(builder, sret_types[returnNum+1]) - elseif count_llvm_Sret == 1 - val - else - @assert count_llvm_Sret > 1 - extract_value!(builder, val, 1-returnNum) - end) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + for returnNum = 0:(count_Sret-1) + eval = fixup_abi( + returnNum + 1, + if count_llvm_Sret == 0 + makeInstanceOf(builder, sret_types[returnNum+1]) + elseif count_llvm_Sret == 1 + val + else + @assert count_llvm_Sret > 1 + extract_value!(builder, val, 1 - returnNum) + end, + ) + ptr = inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), returnNum), + ], + ) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) end @@ -4591,13 +5854,31 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if Mode == API.DEM_ReverseModeCombined if returnPrimal if !isghostty(literal_rt) - eval = fixup_abi(returnNum+1, if !isghostty(actualRetType) - extract_value!(builder, val, returnNum) - else - makeInstanceOf(builder, sret_types[returnNum+1]) - end) - store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), length(elements(jltype))-1 )])) - returnNum+=1 + eval = fixup_abi( + returnNum + 1, + if !isghostty(actualRetType) + extract_value!(builder, val, returnNum) + else + makeInstanceOf(builder, sret_types[returnNum+1]) + end, + ) + store!( + builder, + eval, + inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt( + LLVM.IntType(32), + length(elements(jltype)) - 1, + ), 
+ ], + ), + ) + returnNum += 1 end end end @@ -4607,10 +5888,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, isboxed = GPUCompiler.deserves_argbox(T′) if !isboxed eval = extract_value!(builder, val, returnNum) - store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)])) - returnNum+=1 + store!( + builder, + eval, + inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + LLVM.ConstantInt(LLVM.IntType(32), activeNum), + ], + ), + ) + returnNum += 1 end - activeNum+=1 + activeNum += 1 end end @assert (returnNum - activeNum) + (activeNum != 0 ? 1 : 0) == numLLVMReturns @@ -4618,21 +5912,32 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if returnRoots count = 0 - todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType}[([LLVM.ConstantInt(LLVM.IntType(64), 0)], jltype)] + todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType}[( + [LLVM.ConstantInt(LLVM.IntType(64), 0)], + jltype, + )] while length(todo) != 0 path, ty = popfirst!(todo) if isa(ty, LLVM.PointerType) - loc = inbounds_gep!(builder, root_ty, rootRet, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), count)]) - count+=1 + loc = inbounds_gep!( + builder, + root_ty, + rootRet, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), count), + ], + ) + count += 1 outloc = inbounds_gep!(builder, jltype, sret, path) store!(builder, load!(builder, ty, outloc), loc) continue end if isa(ty, LLVM.ArrayType) if any_jltypes(ty) - for i=1:length(ty) + for i = 1:length(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty))) end end @@ -4640,9 +5945,9 @@ function create_abi_wrapper(enzymefn::LLVM.Function, 
TT, rettype, actualRetType, end if isa(ty, LLVM.VectorType) if any_jltypes(ty) - for i=1:size(ty) + for i = 1:size(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty))) end end @@ -4652,7 +5957,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, for (i, t) in enumerate(LLVM.elements(ty)) if any_jltypes(t) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, t)) end end @@ -4704,69 +6009,103 @@ function fixup_metadata!(f::LLVM.Function) end end -struct RemovedParam -end +struct RemovedParam end # Modified from GPUCompiler classify_arguments -function classify_arguments(source_sig::Type, codegen_ft::LLVM.FunctionType, has_sret::Bool, has_returnroots::Bool, has_swiftself::Bool, parmsRemoved::Vector{UInt64}) +function classify_arguments( + source_sig::Type, + codegen_ft::LLVM.FunctionType, + has_sret::Bool, + has_returnroots::Bool, + has_swiftself::Bool, + parmsRemoved::Vector{UInt64}, +) codegen_types = parameters(codegen_ft) args = [] codegen_i = 1 orig_i = 1 if has_sret - if !in(orig_i-1, parmsRemoved) + if !in(orig_i - 1, parmsRemoved) codegen_i += 1 end orig_i += 1 end if has_returnroots - if !in(orig_i-1, parmsRemoved) + if !in(orig_i - 1, parmsRemoved) codegen_i += 1 end orig_i += 1 end if has_swiftself - if !in(orig_i-1, parmsRemoved) + if !in(orig_i - 1, parmsRemoved) codegen_i += 1 end orig_i += 1 end for (source_i, source_typ) in enumerate(source_sig.parameters) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - push!(args, (cc=GPUCompiler.GHOST, typ=source_typ, arg_i=source_i)) + push!(args, (cc = GPUCompiler.GHOST, typ = source_typ, arg_i = source_i)) continue end - if in(orig_i-1, parmsRemoved) - push!(args, (cc=RemovedParam, typ=source_typ)) + if in(orig_i - 1, parmsRemoved) + push!(args, (cc 
= RemovedParam, typ = source_typ)) orig_i += 1 continue end codegen_typ = codegen_types[codegen_i] if codegen_typ isa LLVM.PointerType - llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) # pointers are used for multiple kinds of arguments # - literal pointer values if source_typ <: Ptr || source_typ <: Core.LLVMPtr @assert llvm_source_typ == codegen_typ - push!(args, (cc=GPUCompiler.BITS_VALUE, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) - # - boxed values - # XXX: use `deserves_retbox` instead? + push!( + args, + ( + cc = GPUCompiler.BITS_VALUE, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + # - boxed values + # XXX: use `deserves_retbox` instead? elseif llvm_source_typ isa LLVM.PointerType @assert llvm_source_typ == codegen_typ - push!(args, (cc=GPUCompiler.MUT_REF, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) - # - references to aggregates + push!( + args, + ( + cc = GPUCompiler.MUT_REF, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + # - references to aggregates else @assert llvm_source_typ != codegen_typ - push!(args, (cc=GPUCompiler.BITS_REF, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) + push!( + args, + ( + cc = GPUCompiler.BITS_REF, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) end else - push!(args, (cc=GPUCompiler.BITS_VALUE, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) + push!( + args, + ( + cc = GPUCompiler.BITS_VALUE, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) end codegen_i += 1 @@ -4778,9 +6117,9 @@ end function isSpecialPtr(Ty) if !isa(Ty, LLVM.PointerType) - return false - end - AS = LLVM.addrspace(Ty) + return false + end + AS = 
LLVM.addrspace(Ty) return 10 <= AS && AS <= 13 end @@ -4791,14 +6130,14 @@ mutable struct CountTrackedPointers end function CountTrackedPointers(T) - res = CountTrackedPointers(0, true, false) + res = CountTrackedPointers(0, true, false) if isa(T, LLVM.PointerType) if isSpecialPtr(T) res.count += 1 if LLVM.addrspace(T) != Tracked res.derived = true - end + end end elseif isa(T, LLVM.StructType) for ElT in elements(T) @@ -4807,43 +6146,43 @@ function CountTrackedPointers(T) res.all &= sub.all res.derived |= sub.derived end - elseif isa(T, LLVM.ArrayType) - sub = CountTrackedPointers(eltype(T)) - res.count += sub.count - res.all &= sub.all - res.derived |= sub.derived - res.count *= length(T) - elseif isa(T, LLVM.VectorType) - sub = CountTrackedPointers(eltype(T)) - res.count += sub.count - res.all &= sub.all - res.derived |= sub.derived - res.count *= size(T) + elseif isa(T, LLVM.ArrayType) + sub = CountTrackedPointers(eltype(T)) + res.count += sub.count + res.all &= sub.all + res.derived |= sub.derived + res.count *= length(T) + elseif isa(T, LLVM.VectorType) + sub = CountTrackedPointers(eltype(T)) + res.count += sub.count + res.all &= sub.all + res.derived |= sub.derived + res.count *= size(T) end if res.count == 0 res.all = false - end - return res + end + return res end # must deserve sret function deserves_rooting(T) - tracked = CountTrackedPointers(T) - @assert !tracked.derived - if tracked.count != 0 && !tracked.all - return true # tracked.count; - end - return false + tracked = CountTrackedPointers(T) + @assert !tracked.derived + if tracked.count != 0 && !tracked.all + return true # tracked.count; + end + return false end # https://github.com/JuliaLang/julia/blob/64378db18b512677fc6d3b012e6d1f02077af191/src/cgutils.cpp#L823 # returns if all unboxed -function for_each_uniontype_small(f, ty, counter=Ref(0)) +function for_each_uniontype_small(f, ty, counter = Ref(0)) if counter[] > 127 return false end if ty isa Union - allunbox = for_each_uniontype_small(f, 
ty.a, counter) + allunbox = for_each_uniontype_small(f, ty.a, counter) allunbox &= for_each_uniontype_small(f, ty.b, counter) return allunbox end @@ -4860,8 +6199,8 @@ end function union_alloca_type(UT) nbytes = 0 function inner(jlrettype) - if !(Base.issingletontype(jlrettype) &&isa(jlrettype, DataType)) - nbytes = max(nbytes, sizeof(jlrettype)) + if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) + nbytes = max(nbytes, sizeof(jlrettype)) end end for_each_uniontype_small(inner, UT) @@ -4873,7 +6212,9 @@ function is_sret(jlrettype) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false - elseif Base.isstructtype(jlrettype) && Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) return false elseif jlrettype isa Union # jl_is_uniontype(jlrettype) @@ -4883,7 +6224,7 @@ function is_sret(jlrettype) end return false elseif !GPUCompiler.deserves_retbox(jlrettype) - rt = convert(LLVMType, jlrettype ) + rt = convert(LLVMType, jlrettype) if !isa(rt, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, rt) return true end @@ -4894,7 +6235,9 @@ function is_sret_union(jlrettype) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false - elseif Base.isstructtype(jlrettype) && Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) return false elseif jlrettype isa Union # jl_is_uniontype(jlrettype) @@ -4907,19 +6250,23 @@ function is_sret_union(jlrettype) end # https://github.com/JuliaLang/julia/blob/0a696a3842750fcedca8832bc0aabe9096c7658f/src/codegen.cpp#L6812 -function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, 
Union{Nothing, Type}, Union{Nothing, Type}} +function get_return_info( + jlrettype, +)::Tuple{Union{Nothing,Type},Union{Nothing,Type},Union{Nothing,Type}} sret = nothing returnRoots = nothing rt = nothing if jlrettype === Union{} rt = Nothing - elseif Base.isstructtype(jlrettype) && Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) rt = Nothing elseif jlrettype isa Union nbytes = 0 allunbox = for_each_uniontype_small(jlrettype) do jlrettype if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) - nbytes = max(nbytes, sizeof(jlrettype)) + nbytes = max(nbytes, sizeof(jlrettype)) end end if nbytes != 0 @@ -4934,7 +6281,7 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, elseif jlrettype <: Tuple && in(Any, jlrettype.parameters) rt = Any elseif !GPUCompiler.deserves_retbox(jlrettype) - lRT = convert(LLVMType, jlrettype ) + lRT = convert(LLVMType, jlrettype) if !isa(lRT, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, lRT) sret = Ptr{jlrettype} tracked = CountTrackedPointers(lRT) @@ -4954,7 +6301,15 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, end # Modified from GPUCompiler/src/irgen.jl:365 lower_byval -function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function, actualRetType::Type, RetActivity, TT, run_enzyme) +function lower_convention( + functy::Type, + mod::LLVM.Module, + entry_f::LLVM.Function, + actualRetType::Type, + RetActivity, + TT, + run_enzyme, +) entry_ft = LLVM.function_type(entry_f) RT = LLVM.return_type(entry_ft) @@ -4977,9 +6332,17 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # TODO removed implications retRemoved, parmsRemoved = removed_ret_parms(entry_f) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(entry_f, i)))) for i in 
1:length(collect(parameters(entry_f)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(entry_f, i)), + ), + ) for i = 1:length(collect(parameters(entry_f))) + ) @assert !swiftself "Swiftself attribute coming from differentiable context is not supported" - prargs = classify_arguments(functy, entry_ft, sret, returnRoots, swiftself, parmsRemoved) + prargs = + classify_arguments(functy, entry_ft, sret, returnRoots, swiftself, parmsRemoved) args = copy(prargs) filter!(args) do arg Base.@_inline_meta @@ -5011,7 +6374,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[]) elseif arg.cc != GPUCompiler.BITS_REF - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme + if TT != nothing && + ( + TT.parameters[arg.arg_i] <: MixedDuplicated || + TT.parameters[arg.arg_i] <: BatchMixedDuplicated + ) && + run_enzyme push!(boxedArgs, arg.arg_i) push!(raisedArgs, arg.arg_i) push!(wrapper_types, LLVM.PointerType(typ, Derived)) @@ -5022,7 +6390,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end else # bits ref, and not boxed - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme + if TT != nothing && + ( + TT.parameters[arg.arg_i] <: MixedDuplicated || + TT.parameters[arg.arg_i] <: BatchMixedDuplicated + ) && + run_enzyme push!(boxedArgs, arg.arg_i) push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) @@ -5048,10 +6421,24 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function set_subprogram!(wrapper_f, sfn) end - hasReturnsTwice = any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(entry_f)))) - hasNoInline = 
any(map(k->kind(k)==kind(EnumAttribute("noinline")), collect(function_attributes(entry_f)))) + hasReturnsTwice = any( + map( + k -> kind(k) == kind(EnumAttribute("returns_twice")), + collect(function_attributes(entry_f)), + ), + ) + hasNoInline = any( + map( + k -> kind(k) == kind(EnumAttribute("noinline")), + collect(function_attributes(entry_f)), + ), + ) if hasNoInline - LLVM.API.LLVMRemoveEnumAttributeAtIndex(entry_f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("noinline"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + entry_f, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("noinline")), + ) end push!(function_attributes(wrapper_f), EnumAttribute("returns_twice")) push!(function_attributes(entry_f), EnumAttribute("returns_twice")) @@ -5089,10 +6476,18 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(nops, load!(builder, convert(LLVMType, arg.typ), parm)) elseif arg.arg_i in raisedArgs obj = emit_allocobj!(builder, arg.typ, "raisedArg") - bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj)))) + bc = bitcast!( + builder, + obj, + LLVM.PointerType(value_type(parm), addrspace(value_type(obj))), + ) store!(builder, parm, bc) emit_writebarrier!(builder, get_julia_inner_types(builder, obj, parm)) - addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived)) + addr = addrspacecast!( + builder, + bc, + LLVM.PointerType(value_type(parm), Derived), + ) push!(nops, addr) else push!(nops, parm) @@ -5133,17 +6528,25 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function dl = string(LLVM.datalayout(LLVM.parent(entry_f))) if sret if !in(0, parmsRemoved) - sretPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1])), "innersret") + sretPtr = alloca!( + builder, + eltype(value_type(parameters(entry_f)[1])), + "innersret", + ) ctx = 
LLVM.context(entry_f) if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, - dl, seen), ctx) + metadata(sretPtr)["enzyme_type"] = + to_md(typetree(Ptr{actualRetType}, ctx, dl, seen), ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) - retRootPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1+sret])), "innerreturnroots") + retRootPtr = alloca!( + builder, + eltype(value_type(parameters(entry_f)[1+sret])), + "innerreturnroots", + ) # retRootPtr = alloca!(builder, parameters(wrapper_f)[1]) push!(wrapper_args, retRootPtr) end @@ -5160,48 +6563,95 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # copy the argument value to a stack slot, and reference it. ty = value_type(parm) if !isa(ty, LLVM.PointerType) - throw(AssertionError("ty is not a LLVM.PointerType: entry_f = $(entry_f), args = $(args), parm = $(parm), ty = $(ty)")) + throw( + AssertionError( + "ty is not a LLVM.PointerType: entry_f = $(entry_f), args = $(args), parm = $(parm), ty = $(ty)", + ), + ) end - ptr = alloca!(builder, eltype(ty), LLVM.name(parm)*".innerparm") + ptr = alloca!(builder, eltype(ty), LLVM.name(parm) * ".innerparm") if TT !== nothing && TT.parameters[arg.arg_i] <: Const metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), - ctx) + metadata(ptr)["enzyme_type"] = + to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end @assert eltype(ty) == value_type(wrapparm) store!(builder, wrapparm, ptr) push!(wrapper_args, ptr) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen)))) - push!(parameter_attributes(wrapper_f, 
arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_VALUE)))) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzyme_type", + string(typetree(arg.typ, ctx, dl, seen)), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(arg.typ))), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_VALUE)), + ), + ) elseif arg.arg_i in raisedArgs wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm) ctx = LLVM.context(wrapparm) push!(wrapper_args, wrapparm) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen)))) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzyme_type", + string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen)), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(arg.typ))), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) else push!(wrapper_args, 
wrapparm) for attr in collect(parameter_attributes(entry_f, arg.codegen.i)) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), attr) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + attr, + ) end end end res = call!(builder, LLVM.function_type(entry_f), entry_f, wrapper_args) if get_subprogram(entry_f) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(entry_f) ) + metadata(res)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(entry_f)) end callconv!(res, LLVM.callconv(entry_f)) if swiftself attr = EnumAttribute("swiftself") - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+sret+returnRoots), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + sret + returnRoots), + attr, + ) end # Box union return, from https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L3138 @@ -5229,9 +6679,28 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function nobj = if sretPtr !== nothing obj = emit_allocobj!(builder, jlrettype, "boxunion") llty = convert(LLVMType, jlrettype) - ld = load!(builder, llty, bitcast!(builder, sretPtr, LLVM.PointerType(llty, addrspace(value_type(sretPtr))))) - store!(builder, ld, bitcast!(builder, obj, LLVM.PointerType(llty, addrspace(value_type(obj))))) - emit_writebarrier!(builder, get_julia_inner_types(builder, obj, ld)) + ld = load!( + builder, + llty, + bitcast!( + builder, + sretPtr, + LLVM.PointerType(llty, addrspace(value_type(sretPtr))), + ), + ) + store!( + builder, + ld, + bitcast!( + builder, + obj, + LLVM.PointerType(llty, addrspace(value_type(obj))), + ), + ) + emit_writebarrier!( + builder, + get_julia_inner_types(builder, obj, ld), + ) # memcpy!(builder, bitcast!(builder, obj, LLVM.PointerType(T_int8, addrspace(value_type(obj)))), 0, bitcast!(builder, sretPtr, LLVM.PointerType(T_int8)), 0, LLVM.ConstantInt(T_int64, sizeof(jlrettype))) obj else 
@@ -5240,35 +6709,93 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function ret!(builder, obj) end - LLVM.API.LLVMAddCase(sw, LLVM.ConstantInt(value_type(scase), counter), BB) - counter+=1 + LLVM.API.LLVMAddCase( + sw, + LLVM.ConstantInt(value_type(scase), counter), + BB, + ) + counter += 1 return end for_each_uniontype_small(inner, actualRetType) position!(builder, def) ret!(builder, extract_value!(builder, res, 0)) - - push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzyme_type", + string(typetree(actualRetType, ctx, dl, seen)), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(actualRetType))), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) end elseif sret if sretPtr === nothing ret!(builder) else - push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzyme_type", + string(typetree(actualRetType, ctx, dl, seen)), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(actualRetType))), + ), + ) + push!( + return_attributes(wrapper_f), + 
StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) ret!(builder, load!(builder, RT, sretPtr)) end elseif LLVM.return_type(entry_ft) == LLVM.VoidType() ret!(builder) else ctx = LLVM.context(wrapper_f) - push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzyme_type", + string(typetree(actualRetType, ctx, dl, seen)), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(actualRetType))), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) ret!(builder, res) end dispose(builder) @@ -5279,12 +6806,18 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function linkage!(entry_f, LLVM.API.LLVMInternalLinkage) fixup_metadata!(entry_f) - + mi, rt = enzyme_custom_extract_mi(entry_f) attributes = function_attributes(wrapper_f) - push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))))) - push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt))))) - + push!( + attributes, + StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))), + ) + push!( + attributes, + StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt)))), + ) + for prev in collect(function_attributes(entry_f)) if kind(prev) == kind(StringAttribute("enzyme_ta_norecur")) push!(attributes, prev) @@ -5313,11 +6846,15 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end 
end if LLVM.version().major > 15 - if kind(prev) == kind(EnumAttribute("memory")) - old = MemoryEffect(value(attr)) - mem = MemoryEffect(( set_writing(getModRef(old, ArgMem)) << getLocationPos(ArgMem)) | (getModRef(old, InaccessibleMem) << getLocationPos(InaccessibleMem)) | (getModRef(old, Other) << getLocationPos(Other))) - push!(attributes, EnumAttribute("memory", mem.data)) - end + if kind(prev) == kind(EnumAttribute("memory")) + old = MemoryEffect(value(attr)) + mem = MemoryEffect( + (set_writing(getModRef(old, ArgMem)) << getLocationPos(ArgMem)) | + (getModRef(old, InaccessibleMem) << getLocationPos(InaccessibleMem)) | + (getModRef(old, Other) << getLocationPos(Other)), + ) + push!(attributes, EnumAttribute("memory", mem.data)) + end end if kind(prev) == kind(EnumAttribute("speculatable")) push!(attributes, prev) @@ -5336,26 +6873,45 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 msg = sprint() do io println(io, string(mod)) - println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println( + io, + LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction), + ) println(io, string(wrapper_f)) - println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs) + println( + io, + "parmsRemoved=", + parmsRemoved, + " retRemoved=", + retRemoved, + " prargs=", + prargs, + ) println(io, "Broken function") end throw(LLVM.LLVMException(msg)) end - ModulePassManager() do pm + ModulePassManager() do pm always_inliner!(pm) LLVM.run!(pm, mod) end if !hasReturnsTwice - LLVM.API.LLVMRemoveEnumAttributeAtIndex(wrapper_f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + wrapper_f, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + 
kind(EnumAttribute("returns_twice")), + ) end if hasNoInline - LLVM.API.LLVMRemoveEnumAttributeAtIndex(wrapper_f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("alwaysinline"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + wrapper_f, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("alwaysinline")), + ) push!(function_attributes(wrapper_f), EnumAttribute("noinline")) end - + # Fix phinodes used exclusively in extractvalue to be separate phi nodes phistofix = LLVM.PHIInst[] for bb in blocks(wrapper_f) @@ -5395,11 +6951,11 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function phis = LLVM.PHIInst[] for (i, t) in enumerate(LLVM.elements(st)) np = phi!(nb, t) - nvs = Tuple{LLVM.Value, LLVM.BasicBlock}[] - for (v, b) in LLVM.incoming(p) + nvs = Tuple{LLVM.Value,LLVM.BasicBlock}[] + for (v, b) in LLVM.incoming(p) prevbld = IRBuilder() position!(prevbld, terminator(b)) - push!(nvs, (extract_value!(prevbld, v, i-1), b)) + push!(nvs, (extract_value!(prevbld, v, i - 1), b)) end append!(LLVM.incoming(np), nvs) push!(phis, np) @@ -5429,7 +6985,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if haskey(globals(mod), "llvm.used") eraseInst(mod, globals(mod)["llvm.used"]) for u in user.(collect(uses(entry_f))) - if isa(u, LLVM.GlobalVariable) && endswith(LLVM.name(u), "_slot") && startswith(LLVM.name(u), "julia") + if isa(u, LLVM.GlobalVariable) && + endswith(LLVM.name(u), "_slot") && + startswith(LLVM.name(u), "julia") eraseInst(mod, u) end end @@ -5438,7 +6996,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 msg = sprint() do io println(io, string(mod)) - println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println( + io, + LVM.API.LLVMVerifyFunction(wrapper_f, 
LLVM.API.LLVMPrintMessageAction), + ) println(io, string(wrapper_f)) println(io, "Broken function") end @@ -5449,7 +7010,7 @@ end using Random # returns arg, return -function no_type_setting(@nospecialize(specTypes); world=nothing) +function no_type_setting(@nospecialize(specTypes); world = nothing) # Even though the julia type here is ptr{int8}, the actual data can be something else if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) return (true, false) @@ -5462,22 +7023,35 @@ end const DumpPreOpt = Ref(false) -function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; - libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, - strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) - params = job.config.params +function GPUCompiler.codegen( + output::Symbol, + job::CompilerJob{<:EnzymeTarget}; + libraries::Bool = true, + deferred_codegen::Bool = true, + optimize::Bool = true, + toplevel::Bool = true, + strip::Bool = false, + validate::Bool = true, + only_entry::Bool = false, + parent_job::Union{Nothing,CompilerJob} = nothing, +) + params = job.config.params if params.run_enzyme @assert eltype(params.rt) != Union{} end expectedTapeType = params.expectedTapeType - mode = params.mode + mode = params.mode TT = params.TT width = params.width abiwrap = params.abiwrap - primal = job.source + primal = job.source modifiedBetween = params.modifiedBetween - if length(modifiedBetween) != length(TT.parameters) - throw(AssertionError("length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT")) + if length(modifiedBetween) != length(TT.parameters) + throw( + AssertionError( + "length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT", + ), + ) end returnPrimal = params.returnPrimal @@ -5487,21 +7061,40 @@ 
function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if parent_job === nothing primal_target = DefaultCompilerTarget() primal_params = PrimalCompilerParams(mode) - primal_job = CompilerJob(primal, CompilerConfig(primal_target, primal_params; kernel=false), job.world) + primal_job = CompilerJob( + primal, + CompilerConfig(primal_target, primal_params; kernel = false), + job.world, + ) else - config2 = CompilerConfig(parent_job.config.target, parent_job.config.params; kernel=false, parent_job.config.entry_abi, parent_job.config.name, parent_job.config.always_inline) + config2 = CompilerConfig( + parent_job.config.target, + parent_job.config.params; + kernel = false, + parent_job.config.entry_abi, + parent_job.config.name, + parent_job.config.always_inline, + ) primal_job = CompilerJob(primal, config2, job.world) # TODO EnzymeInterp params, etc end - mod, meta = GPUCompiler.codegen(:llvm, primal_job; optimize=false, toplevel=toplevel, cleanup=false, validate=false, parent_job=parent_job) + mod, meta = GPUCompiler.codegen( + :llvm, + primal_job; + optimize = false, + toplevel = toplevel, + cleanup = false, + validate = false, + parent_job = parent_job, + ) prepare_llvm(mod, primal_job, meta) for f in functions(mod) permit_inlining!(f) end LLVM.ModulePassManager() do pm - API.AddPreserveNVVMPass!(pm, #=Begin=#true) + API.AddPreserveNVVMPass!(pm, true) #=Begin=# LLVM.run!(pm, mod) end @@ -5510,34 +7103,63 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; disableFallback = String[] - ForwardModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "symm", "syrk", "potrf") - ReverseModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "symm", "trmv", "syrk", "trmm", "trsm", "potrf") + ForwardModeDerivatives = + ("nrm2", "dot", "gemm", "gemv", "axpy", "copy", "scal", "symm", "syrk", "potrf") + ReverseModeDerivatives = ( + "nrm2", + "dot", + "gemm", + "gemv", + "axpy", + "copy", + "scal", + 
"symm", + "trmv", + "syrk", + "trmm", + "trsm", + "potrf", + ) ForwardModeTypes = ("s", "d", "c", "z") ReverseModeTypes = ("s", "d") # Tablegen BLAS does not support forward mode yet if !(mode == API.DEM_ForwardMode && params.runtimeActivity) for ty in (mode == API.DEM_ForwardMode ? ForwardModeTypes : ReverseModeTypes) - for func in (mode == API.DEM_ForwardMode ? ForwardModeDerivatives : ReverseModeDerivatives) + for func in ( + mode == API.DEM_ForwardMode ? ForwardModeDerivatives : + ReverseModeDerivatives + ) for prefix in ("", "cblas_") for ending in ("", "_", "64_", "_64_") - push!(disableFallback, prefix*ty*func*ending) + push!(disableFallback, prefix * ty * func * ending) end end end end end found = String[] - if bitcode_replacement() && API.EnzymeBitcodeReplacement(mod, disableFallback, found) != 0 + if bitcode_replacement() && + API.EnzymeBitcodeReplacement(mod, disableFallback, found) != 0 ModulePassManager() do pm instruction_combining!(pm) LLVM.run!(pm, mod) end toremove = [] for f in functions(mod) - if !any(map(k->kind(k)==kind(EnumAttribute("alwaysinline")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("alwaysinline")), + collect(function_attributes(f)), + ), + ) continue end - if !any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("returns_twice")), + collect(function_attributes(f)), + ), + ) push!(function_attributes(f), EnumAttribute("returns_twice")) push!(toremove, name(f)) end @@ -5578,7 +7200,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for fname in toremove if haskey(functions(mod), fname) f = functions(mod)[fname] - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + reinterpret( + LLVM.API.LLVMAttributeIndex, 
+ LLVM.API.LLVMAttributeFunctionIndex, + ), + kind(EnumAttribute("returns_twice")), + ) end end GPUCompiler.@safe_warn "Using fallback BLAS replacements for ($found), performance may be degraded" @@ -5587,16 +7216,16 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; LLVM.run!(pm, mod) end end - + for f in functions(mod) mi, RT = enzyme_custom_extract_mi(f, false) if mi === nothing continue end - llRT, sret, returnRoots = get_return_info(RT) + llRT, sret, returnRoots = get_return_info(RT) retRemoved, parmsRemoved = removed_ret_parms(f) - + dl = string(LLVM.datalayout(LLVM.parent(f))) expectLen = (sret !== nothing) + (returnRoots !== nothing) @@ -5604,11 +7233,18 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) continue end - expectLen+=1 + expectLen += 1 end expectLen -= length(parmsRemoved) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(f, i)), + ), + ) for i = 1:length(collect(parameters(f))) + ) if swiftself expectLen += 1 @@ -5652,7 +7288,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; push!( parameter_attributes(f, arg.codegen.i), StringAttribute( - "enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))) + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(arg.typ))), ), ) push!( @@ -5705,7 +7342,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end - if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() + if llRT !== nothing && + LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() @assert !retRemoved rest = typetree(llRT, ctx, dl) push!(return_attributes(f), 
StringAttribute("enzyme_type", string(rest))) @@ -5726,7 +7364,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; actualRetType = nothing lowerConvention = true customDerivativeNames = String[] - fnsToInject = Tuple{Symbol, Type}[] + fnsToInject = Tuple{Symbol,Type}[] for (mi, k) in meta.compiled k_name = GPUCompiler.safe_name(k.specfunc) has_custom_rule = false @@ -5735,12 +7373,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; caller = mi if mode == API.DEM_ForwardMode - has_custom_rule = EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller) + has_custom_rule = + EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller) if has_custom_rule @safe_debug "Found frule for" mi.specTypes end else - has_custom_rule = EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller) + has_custom_rule = + EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller) if has_custom_rule @safe_debug "Found rrule for" mi.specTypes end @@ -5754,7 +7394,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if llvmfn == primalf actualRetType = k.ci.rettype end - + if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller) push!(return_attributes(llvmfn), EnumAttribute("noalias")) for u in LLVM.uses(llvmfn) @@ -5764,22 +7404,26 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end cf = LLVM.called_operand(c) if cf == llvmfn - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) end end end func = mi.specTypes.parameters[1] - + meth = mi.def name = meth.name - jlmod = meth.module + jlmod = meth.module - function handleCustom(llvmfn, name, attrs=[], setlink=true, noinl=true) + function handleCustom(llvmfn, name, attrs = [], 
setlink = true, noinl = true) attributes = function_attributes(llvmfn) custom[k_name] = linkage(llvmfn) if setlink - linkage!(llvmfn, LLVM.API.LLVMExternalLinkage) + linkage!(llvmfn, LLVM.API.LLVMExternalLinkage) end for a in attrs push!(attributes, a) @@ -5794,91 +7438,147 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; julia_activity_rule(llvmfn) if has_custom_rule - handleCustom(llvmfn, "enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")]) + handleCustom( + llvmfn, + "enzyme_custom", + [StringAttribute("enzyme_preserve_primal", "*")], + ) continue end sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals - if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) + if func == typeof(Base.eps) || + func == typeof(Base.nextfloat) || + func == typeof(Base.prevfloat) if LLVM.version().major <= 15 - handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), - EnumAttribute("readnone"), - EnumAttribute("speculatable"), - StringAttribute("enzyme_shouldrecompute") - ]) + handleCustom( + llvmfn, + "jl_inactive_inout", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("readnone"), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + ], + ) else - handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), - EnumAttribute("memory", NoEffects.data), - EnumAttribute("speculatable"), - StringAttribute("enzyme_shouldrecompute") - ]) + handleCustom( + llvmfn, + "jl_inactive_inout", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("memory", NoEffects.data), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + ], + ) end continue end if func == typeof(Base.to_tuple_type) if LLVM.version().major <= 15 - handleCustom(llvmfn, "jl_to_tuple_type", - [EnumAttribute("readonly"), + handleCustom( + llvmfn, + "jl_to_tuple_type", + [ + EnumAttribute("readonly"), 
EnumAttribute("inaccessiblememonly", 0), EnumAttribute("speculatable", 0), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - ]) + ], + ) else - handleCustom(llvmfn, "jl_to_tuple_type", - [ - EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + handleCustom( + llvmfn, + "jl_to_tuple_type", + [ + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), EnumAttribute("inaccessiblememonly", 0), EnumAttribute("speculatable", 0), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - ]) + ], + ) end continue end if func == typeof(Base.mightalias) if LLVM.version().major <= 15 - handleCustom(llvmfn, "jl_mightalias", - [EnumAttribute("readonly"), + handleCustom( + llvmfn, + "jl_mightalias", + [ + EnumAttribute("readonly"), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), StringAttribute("enzyme_no_escaping_allocation"), EnumAttribute("nofree"), StringAttribute("enzyme_ta_norecur"), - ], true, false) + ], + true, + false, + ) else - handleCustom(llvmfn, "jl_mightalias", - [ + handleCustom( + llvmfn, + "jl_mightalias", + [ EnumAttribute("memory", ReadOnlyEffects.data), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), StringAttribute("enzyme_no_escaping_allocation"), EnumAttribute("nofree"), StringAttribute("enzyme_ta_norecur"), - ], true, false) + ], + true, + false, + ) end continue end if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) name = (func == typeof(Base.Threads.threadid)) ? 
"jl_threadid" : "jl_nthreads" if LLVM.version().major <= 15 - handleCustom(llvmfn, name, - [EnumAttribute("readonly"), + handleCustom( + llvmfn, + name, + [ + EnumAttribute("readonly"), EnumAttribute("inaccessiblememonly"), EnumAttribute("speculatable"), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - StringAttribute("enzyme_no_escaping_allocation") - ]) + StringAttribute("enzyme_no_escaping_allocation"), + ], + ) else - handleCustom(llvmfn, name, - [EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + handleCustom( + llvmfn, + name, + [ + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), EnumAttribute("speculatable"), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - StringAttribute("enzyme_no_escaping_allocation") - ]) + StringAttribute("enzyme_no_escaping_allocation"), + ], + ) end continue end @@ -5889,45 +7589,143 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] if LLVM.version().major <= 15 - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("readonly"), + StringAttribute("enzyme_ta_norecur"), + ], + ) else - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), - EnumAttribute("memory", ReadOnlyEffects.data), - StringAttribute("enzyme_ta_norecur")]) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("memory", ReadOnlyEffects.data), + 
StringAttribute("enzyme_ta_norecur"), + ], + ) end continue end - if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive), specTypes.parameters...}, world, method_table) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")]) + if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && + has_method( + Tuple{typeof(EnzymeRules.inactive),specTypes.parameters...}, + world, + method_table, + ) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("nofree"), + StringAttribute("enzyme_no_escaping_allocation"), + StringAttribute("enzyme_ta_norecur"), + ], + ) continue end - if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive_noinl), specTypes.parameters...}, world, method_table) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")], false, false) + if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && + has_method( + Tuple{typeof(EnzymeRules.inactive_noinl),specTypes.parameters...}, + world, + method_table, + ) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("nofree"), + StringAttribute("enzyme_no_escaping_allocation"), + StringAttribute("enzyme_ta_norecur"), + ], + false, + false, + ) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, 
LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("no_escaping_allocation"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("enzyme_inactive"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + EnumAttribute("nofree"), + ) end end end continue end if func === typeof(Base.match) - handleCustom(llvmfn, "base_match", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")], false, false) + handleCustom( + llvmfn, + "base_match", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("nofree"), + StringAttribute("enzyme_no_escaping_allocation"), + ], + false, + false, + ) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("no_escaping_allocation"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + 
LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("enzyme_inactive"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + EnumAttribute("nofree"), + ) end end end continue end - if func == typeof(Base.enq_work) && length(sparam_vals) == 1 && first(sparam_vals) <: Task + if func == typeof(Base.enq_work) && + length(sparam_vals) == 1 && + first(sparam_vals) <: Task handleCustom(llvmfn, "jl_enq_work", [StringAttribute("enzyme_ta_norecur")]) continue end @@ -5943,7 +7741,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end continue end - + name, toinject, T = find_math_method(func, sparam_vals) if name === nothing continue @@ -5956,17 +7754,25 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # If sret, force lower of primitive math fn sret = get_return_info(k.ci.rettype)[2] !== nothing if sret - cur = llvmfn == primalf - llvmfn, _, boxedArgs, loweredArgs = lower_convention(mi.specTypes, mod, llvmfn, k.ci.rettype, Duplicated, nothing, params.run_enzyme) - if cur - primalf = llvmfn - lowerConvention = false - end - k_name = LLVM.name(llvmfn) + cur = llvmfn == primalf + llvmfn, _, boxedArgs, loweredArgs = lower_convention( + mi.specTypes, + mod, + llvmfn, + k.ci.rettype, + Duplicated, + nothing, + params.run_enzyme, + ) + if cur + primalf = llvmfn + lowerConvention = false + end + k_name = LLVM.name(llvmfn) end name = string(name) - name = T == Float32 ? name*"f" : name + name = T == Float32 ? 
name * "f" : name attrs = if LLVM.version().major <= 15 [LLVM.EnumAttribute("readnone"), StringAttribute("enzyme_shouldrecompute")] @@ -5985,13 +7791,18 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; llvmfn = primalf FT = LLVM.function_type(llvmfn) - wrapper_f = LLVM.Function(mod, safe_name(LLVM.name(llvmfn)*"mustwrap"), FT) + wrapper_f = LLVM.Function(mod, safe_name(LLVM.name(llvmfn) * "mustwrap"), FT) let builder = IRBuilder() entry = BasicBlock(wrapper_f, "entry") position!(builder, entry) - res = call!(builder, LLVM.function_type(llvmfn), llvmfn, collect(parameters(wrapper_f))) + res = call!( + builder, + LLVM.function_type(llvmfn), + llvmfn, + collect(parameters(wrapper_f)), + ) sretkind = kind(if LLVM.version().major >= 12 TypeAttribute("sret", LLVM.Int32Type()) @@ -6001,7 +7812,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for idx in length(collect(parameters(llvmfn))) for attr in collect(parameter_attributes(llvmfn, idx)) if kind(attr) == sretkind - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(idx), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(idx), + attr, + ) end end end @@ -6017,8 +7832,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; attributes = function_attributes(wrapper_f) push!(attributes, StringAttribute("enzymejl_world", string(job.world))) mi, rt = enzyme_custom_extract_mi(primalf) - push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))))) - push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt))))) + push!( + attributes, + StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))), + ) + push!( + attributes, + StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt)))), + ) primalf = wrapper_f end @@ -6027,10 +7848,18 @@ function GPUCompiler.codegen(output::Symbol, 
job::CompilerJob{<:EnzymeTarget}; primalf, returnRoots = primalf, false - if lowerConvention - primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT, params.run_enzyme) + if lowerConvention + primalf, returnRoots, boxedArgs, loweredArgs = lower_convention( + source_sig, + mod, + primalf, + actualRetType, + job.config.params.rt, + TT, + params.run_enzyme, + ) end - + if primal_job.config.target isa GPUCompiler.NativeCompilerTarget target_machine = JIT.get_tm() else @@ -6042,13 +7871,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; device_module = false if parent_job !== nothing if parent_job.config.target isa GPUCompiler.PTXCompilerTarget || - parent_job.config.target isa GPUCompiler.GCNCompilerTarget || - parent_job.config.target isa GPUCompiler.MetalCompilerTarget + parent_job.config.target isa GPUCompiler.GCNCompilerTarget || + parent_job.config.target isa GPUCompiler.MetalCompilerTarget parallel = true device_module = true end if parent_job.config.target isa GPUCompiler.GCNCompilerTarget || - parent_job.config.target isa GPUCompiler.MetalCompilerTarget + parent_job.config.target isa GPUCompiler.MetalCompilerTarget process_module = true end end @@ -6074,7 +7903,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if process_module GPUCompiler.optimize_module!(parent_job, mod) end - + for name in ("gpu_report_exception", "report_exception") if haskey(functions(mod), name) exc = functions(mod)[name] @@ -6092,42 +7921,51 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for f in functions(mod), bb in blocks(f), inst in instructions(bb) fn = isa(inst, LLVM.CallInst) ? 
LLVM.called_operand(inst) : nothing - if !API.HasFromStack(inst) && isa(inst, LLVM.CallInst) && (!isa(fn, LLVM.Function) || isempty(blocks(fn))) - legal, source_typ = abs_typeof(inst) + if !API.HasFromStack(inst) && + isa(inst, LLVM.CallInst) && + (!isa(fn, LLVM.Function) || isempty(blocks(fn))) + legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal typ = if codegen_typ isa LLVM.PointerType - llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) # pointers are used for multiple kinds of arguments # - literal pointer values if source_typ <: Ptr || source_typ <: Core.LLVMPtr source_typ - elseif llvm_source_typ isa LLVM.PointerType - #if llvm_source_typ != codegen_typ - # throw(AssertionError("llvmtype ($llvm_source_typ) is not codegen_typ ($codegen_typ), source_typ = $source_typ within $(string(inst))")) - #end - # push!(args, (cc=MUT_REF, typ=source_typ, name=source_name, idx=codegen_i)) + elseif byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF Ptr{source_typ} - # - references to aggregates else - @assert llvm_source_typ != codegen_typ - # push!(args, (cc=BITS_REF, typ=source_typ, name=source_name, idx=codegen_i)) - Ptr{source_typ} + println(string(mod)) + @show legal, source_typ, byref, llvm_source_typ, codegen_typ, string(inst) + @assert false end else source_typ end if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", string(typetree(typ, ctx, dl, seen)))) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + LLVM.API.LLVMAttributeReturnIndex, + StringAttribute( + "enzyme_type", + string(typetree(typ, ctx, dl, seen)), + ), + ) else metadata(inst)["enzyme_type"] = to_md(typetree(typ, ctx, dl, seen), ctx) end elseif codegen_typ == T_prjlvalue if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, 
StringAttribute("enzyme_type", "{[-1]:Pointer}")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + LLVM.API.LLVMAttributeReturnIndex, + StringAttribute("enzyme_type", "{[-1]:Pointer}"), + ) else - metadata(inst)["enzyme_type"] = to_md(typetree(Ptr{Cvoid}, ctx, dl, seen), ctx) + metadata(inst)["enzyme_type"] = + to_md(typetree(Ptr{Cvoid}, ctx, dl, seen), ctx) end end end @@ -6139,20 +7977,36 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if length(blocks(fn)) != 0 continue end - + intr = LLVM.API.LLVMGetIntrinsicID(fn) - if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id - legal, jTy = abs_typeof(operands(inst)[1]) - sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id - operands(inst)[3] - else - operands(inst)[3] - end + if intr == LLVM.Intrinsic("llvm.memcpy").id || + intr == LLVM.Intrinsic("llvm.memmove").id || + intr == LLVM.Intrinsic("llvm.memset").id + legal, jTy, byref = abs_typeof(operands(inst)[1]) + sz = + if intr == LLVM.Intrinsic("llvm.memcpy").id || + intr == LLVM.Intrinsic("llvm.memmove").id + operands(inst)[3] + else + operands(inst)[3] + end if legal && Base.isconcretetype(jTy) - if !(jTy isa UnionAll || jTy isa Union || jTy == Union{} || jTy === Tuple || (is_concrete_tuple(jTy) && any(T2 isa Core.TypeofVararg for T2 in jTy.parameters))) + if !( + jTy isa UnionAll || + jTy isa Union || + jTy == Union{} || + jTy === Tuple || + ( + is_concrete_tuple(jTy) && + any(T2 isa Core.TypeofVararg for T2 in jTy.parameters) + ) + ) if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz) - metadata(inst)["enzyme_truetype"] = to_fullmd(jTy) + md = to_fullmd(jTy) + @assert byref == GPUCompiler.BITS_REF || + byref == GPUCompiler.MUT_REF + metadata(inst)["enzyme_truetype"] = md end end end @@ -6164,15 +8018,19 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end - 
legal, jTy = abs_typeof(inst, true) + legal, jTy, byref = abs_typeof(inst, true) if !legal continue end if !guaranteed_const_nongen(jTy, world) continue - end + end if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_inactive")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + LLVM.API.LLVMAttributeReturnIndex, + StringAttribute("enzyme_inactive"), + ) else metadata(inst)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end @@ -6186,10 +8044,17 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; Ty = eltype(FT) reg = active_reg_inner(Ty, (), world) if reg == DupState || reg == MixedState - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(primalf, i)))) for i in 1:length(collect(parameters(primalf)))) - todo = LLVM.Value[parameters(primalf)[1+swiftself]] - done = Set{LLVM.Value}() - doneInst = Set{LLVM.Instruction}() + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(primalf, i)), + ), + ) for i = 1:length(collect(parameters(primalf))) + ) + todo = LLVM.Value[parameters(primalf)[1+swiftself]] + done = Set{LLVM.Value}() + doneInst = Set{LLVM.Instruction}() while length(todo) != 0 cur = pop!(todo) if cur in done @@ -6206,7 +8071,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end if !mayWriteToMemory(user) - slegal , foundv = abs_typeof(user) + slegal, foundv, byref = abs_typeof(user) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == ActiveState || reg2 == AnyState @@ -6221,7 +8086,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # we are capturing the variable if operands(user)[1] == cur base = operands(user)[2] - while isa(base, LLVM.BitCastInst) || isa(base, LLVM.AddrSpaceCastInst) || isa(base, LLVM.GetElementPtrInst) + while isa(base, LLVM.BitCastInst) || + isa(base, 
LLVM.AddrSpaceCastInst) || + isa(base, LLVM.GetElementPtrInst) base = operands(base)[1] end if isa(base, LLVM.AllocaInst) @@ -6232,7 +8099,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end # we are storing into the variable if operands(user)[2] == cur - slegal , foundv = abs_typeof(operands(user)[1]) + slegal, foundv, byref = abs_typeof(operands(user)[1]) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == AnyState @@ -6253,13 +8120,16 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end nm = LLVM.name(called) - if nm == "ijl_alloc_array_1d" || nm == "jl_alloc_array_1d" || - nm == "ijl_alloc_array_2d" || nm == "jl_alloc_array_2d" || - nm == "ijl_alloc_array_3d" || nm == "jl_alloc_array_3d" + if nm == "ijl_alloc_array_1d" || + nm == "jl_alloc_array_1d" || + nm == "ijl_alloc_array_2d" || + nm == "jl_alloc_array_2d" || + nm == "ijl_alloc_array_3d" || + nm == "jl_alloc_array_3d" continue end if is_readonly(called) - slegal , foundv = abs_typeof(user) + slegal, foundv, byref = abs_typeof(user) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == ActiveState || reg2 == AnyState @@ -6269,13 +8139,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; push!(todo, user) continue end - if !isempty(blocks(called)) && length(collect(LLVM.uses(called))) == 1 - for (parm, op) in zip(LLVM.parameters(called), operands(user)[1:end-1]) + if !isempty(blocks(called)) && + length(collect(LLVM.uses(called))) == 1 + for (parm, op) in + zip(LLVM.parameters(called), operands(user)[1:end-1]) if op == cur push!(todo, parm) end end - slegal , foundv = abs_typeof(user) + slegal, foundv, byref = abs_typeof(user) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == ActiveState || reg2 == AnyState @@ -6290,10 +8162,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; builder = LLVM.IRBuilder() position!(builder, user) - resstr = 
"Function argument passed to autodiff cannot be proven readonly.\nIf the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)\nSee https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.\nThe potentially writing call is "*string(user)*", using "*string(cur) - slegal , foundv = absint(cur) + resstr = + "Function argument passed to autodiff cannot be proven readonly.\nIf the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)\nSee https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.\nThe potentially writing call is " * + string(user) * + ", using " * + string(cur) + slegal, foundv = absint(cur) if slegal - resstr *= "of type "*string(foundv) + resstr *= "of type " * string(foundv) end emit_error(builder, user, resstr, EnzymeMutabilityException) end @@ -6339,7 +8215,10 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; cf = LLVM.called_operand(tmp) if isa(cf, LLVM.Function) nm = LLVM.name(cf) - if nm == "gpu_signal_exception" || nm == "gpu_report_exception" || nm == "ijl_throw" || nm == "jl_throw" + if nm == "gpu_signal_exception" || + nm == "gpu_report_exception" || + nm == "ijl_throw" || + nm == "jl_throw" shouldemit = false break end @@ -6350,14 +8229,28 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if shouldemit b = IRBuilder() position!(b, term) - emit_error(b, term, "Enzyme: The original primal code hits this error condition, thus differentiating it does not make sense") + emit_error( + b, + term, + "Enzyme: The original primal code hits this error condition, thus differentiating it does not make sense", + ) end end end - if !any(map(k->kind(k)==kind(EnumAttribute("alwaysinline")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("alwaysinline")), + 
collect(function_attributes(f)), + ), + ) continue end - if !any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("returns_twice")), + collect(function_attributes(f)), + ), + ) push!(function_attributes(f), EnumAttribute("returns_twice")) push!(toremove, name(f)) end @@ -6369,7 +8262,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for fname in toremove if haskey(functions(mod), fname) f = functions(mod)[fname] - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + kind(EnumAttribute("returns_twice")), + ) end end else @@ -6378,86 +8278,192 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end LLVM.ModulePassManager() do pm - API.AddPreserveNVVMPass!(pm, #=Begin=#false) + API.AddPreserveNVVMPass!(pm, false) #=Begin=# LLVM.run!(pm, mod) end if parent_job !== nothing if parent_job.config.target isa GPUCompiler.PTXCompilerTarget - arg1 = ("sin", "cos", "tan", "log2", "exp", "exp2", - "exp10", "cosh", "sinh", "tanh", "atan", - "asin", "acos", "log", "log10", "log1p", "acosh", - "asinh", "atanh", "expm1", "cbrt", - "rcbrt", "j0", "j1", "y0", "y1", - "erf", "erfinv", "erfc", "erfcx", "erfcinv", - "remquo", "tgamma", - "round", "fdim", "logb", "isinf", - "sqrt", "fabs", "atan2", ) - # isinf, finite "modf", "fmod", "remainder", - # "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", - # "hypot", "rhypot", - # "yn", "jn", "norm3d", "ilogb", powi - # "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", - # arg1 = ("atan2", "fmax", "pow") - for n in arg1, (T, pf, lpf) in ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) - fname = "__nv_"*n*pf - if 
!haskey(functions(mod), fname) - FT = LLVM.FunctionType(T, [T], vararg=false) - wrapper_f = LLVM.Function(mod, fname, FT) - llname = "llvm."*n*"."*lpf - push!(function_attributes(wrapper_f), StringAttribute("implements", llname)) - end - end - end + arg1 = ( + "sin", + "cos", + "tan", + "log2", + "exp", + "exp2", + "exp10", + "cosh", + "sinh", + "tanh", + "atan", + "asin", + "acos", + "log", + "log10", + "log1p", + "acosh", + "asinh", + "atanh", + "expm1", + "cbrt", + "rcbrt", + "j0", + "j1", + "y0", + "y1", + "erf", + "erfinv", + "erfc", + "erfcx", + "erfcinv", + "remquo", + "tgamma", + "round", + "fdim", + "logb", + "isinf", + "sqrt", + "fabs", + "atan2", + ) + # isinf, finite "modf", "fmod", "remainder", + # "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", + # "hypot", "rhypot", + # "yn", "jn", "norm3d", "ilogb", powi + # "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", + # arg1 = ("atan2", "fmax", "pow") + for n in arg1, + (T, pf, lpf) in + ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) + + fname = "__nv_" * n * pf + if !haskey(functions(mod), fname) + FT = LLVM.FunctionType(T, [T], vararg = false) + wrapper_f = LLVM.Function(mod, fname, FT) + llname = "llvm." * n * "." 
* lpf + push!( + function_attributes(wrapper_f), + StringAttribute("implements", llname), + ) + end + end + end if parent_job.config.target isa GPUCompiler.GCNCompilerTarget - arg1 = ("acos", "acosh", "asin", - "asinh", "atan2", "atan", - "atanh", "cbrt", "ceil", - "copysign", "cos", "native_cos", - "cosh", "cospi", "i0", - "i1", "erfc", "erfcinv", - "erfcx", "erf", "erfinv", - "exp10", "native_exp10", "exp2", - "exp", "native_exp", "expm1", - "fabs", "fdim", "floor", - "fma", "fmax", "fmin", - "fmod", "frexp", "hypot", - "ilogb", "isfinite", "isinf", - "isnan", "j0", "j1", - "ldexp", "lgamma", "log10", - "native_log10", "log1p", "log2", - "log2", "logb", "log", - "native_log", "modf", "nearbyint", - "nextafter", "len3", "len4", - "ncdf", "ncdfinv", "pow", - "pown", "rcbrt", "remainder", - "remquo", "rhypot", "rint", - "rlen3", "rlen4", "round", - "rsqrt", "scalb", "scalbn", - "signbit", "sincos", "sincospi", - "sin", "native_sin", "sinh", - "sinpi", "sqrt", "native_sqrt", - "tan", "tanh", "tgamma", - "trunc", "y0", "y1") - for n in arg1, (T, pf, lpf) in ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) - fname = "__ocml_"*n*"_"*lpf + arg1 = ( + "acos", + "acosh", + "asin", + "asinh", + "atan2", + "atan", + "atanh", + "cbrt", + "ceil", + "copysign", + "cos", + "native_cos", + "cosh", + "cospi", + "i0", + "i1", + "erfc", + "erfcinv", + "erfcx", + "erf", + "erfinv", + "exp10", + "native_exp10", + "exp2", + "exp", + "native_exp", + "expm1", + "fabs", + "fdim", + "floor", + "fma", + "fmax", + "fmin", + "fmod", + "frexp", + "hypot", + "ilogb", + "isfinite", + "isinf", + "isnan", + "j0", + "j1", + "ldexp", + "lgamma", + "log10", + "native_log10", + "log1p", + "log2", + "log2", + "logb", + "log", + "native_log", + "modf", + "nearbyint", + "nextafter", + "len3", + "len4", + "ncdf", + "ncdfinv", + "pow", + "pown", + "rcbrt", + "remainder", + "remquo", + "rhypot", + "rint", + "rlen3", + "rlen4", + "round", + "rsqrt", + "scalb", + "scalbn", + "signbit", + 
"sincos", + "sincospi", + "sin", + "native_sin", + "sinh", + "sinpi", + "sqrt", + "native_sqrt", + "tan", + "tanh", + "tgamma", + "trunc", + "y0", + "y1", + ) + for n in arg1, + (T, pf, lpf) in + ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) + + fname = "__ocml_" * n * "_" * lpf if !haskey(functions(mod), fname) - FT = LLVM.FunctionType(T, [T], vararg=false) + FT = LLVM.FunctionType(T, [T], vararg = false) wrapper_f = LLVM.Function(mod, fname, FT) - llname = "llvm."*n*"."*lpf - push!(function_attributes(wrapper_f), StringAttribute("implements", llname)) + llname = "llvm." * n * "." * lpf + push!( + function_attributes(wrapper_f), + StringAttribute("implements", llname), + ) end end end - end + end for (name, fnty) in fnsToInject - for (T, JT, pf) in ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) - fname = String(name)*pf + for (T, JT, pf) in + ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) + fname = String(name) * pf if haskey(functions(mod), fname) funcspec = GPUCompiler.methodinstance(fnty, Tuple{JT}, world) llvmf = nested_codegen!(mode, mod, funcspec, world) push!(function_attributes(llvmf), StringAttribute("implements", fname)) end - end + end end API.EnzymeReplaceFunctionImplementation(mod) @@ -6478,7 +8484,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end end - for fname in ["__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"] + for fname in + ["__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"] haskey(functions(mod), fname) || continue f = functions(mod)[fname] for u in uses(f) @@ -6504,7 +8511,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if parent_job !== nothing reinsert_gcmarker!(adjointf) augmented_primalf !== nothing && reinsert_gcmarker!(augmented_primalf) - post_optimze!(mod, target_machine, #=machine=#false) + post_optimze!(mod, target_machine, false) 
#=machine=# end adjointf = functions(mod)[adjointf_name] @@ -6525,34 +8532,127 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; use_primal = mode == API.DEM_ReverseModePrimal entry = use_primal ? augmented_primalf : adjointf - return mod, (;adjointf, augmented_primalf, entry, compiled=meta.compiled, TapeType) + return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType) end # Compiler result -struct CompileResult{AT, PT} +struct CompileResult{AT,PT} adjoint::AT primal::PT TapeType::Type end -@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) - -@inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) +@inline (thunk::PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal})( + fn::FA, + args..., +) where {PT,FA,RT,TT,Width,ReturnPrimal} = enzyme_call( + Val(false), + thunk.adjoint, + PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal}, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + Cvoid, + args..., +) -@inline (thunk::ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) 
+@inline (thunk::CombinedAdjointThunk{PT,FA,RT,TT,Width,ReturnPrimal})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal} = enzyme_call( + Val(false), + thunk.adjoint, + CombinedAdjointThunk{PT,FA,RT,TT,Width,ReturnPrimal}, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + Cvoid, + args..., +) -@inline (thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT})(fn::FA, args...) where {PT, FA, Width, RT, TT, TapeT} = -enzyme_call(Val(false), thunk.adjoint, AdjointThunk{PT, FA, RT, TT, Width, TapeT}, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) -@inline raw_enzyme_call(thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT}, fn::FA, args...) where {PT, FA, Width, RT, TT, TapeT} = -enzyme_call(Val(true), thunk.adjoint, AdjointThunk, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) +@inline (thunk::ForwardModeThunk{PT,FA,RT,TT,Width,ReturnPrimal})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal} = enzyme_call( + Val(false), + thunk.adjoint, + ForwardModeThunk{PT,FA,RT,TT,Width,ReturnPrimal}, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + Cvoid, + args..., +) -@inline (thunk::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeT})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal, TapeT} = -enzyme_call(Val(false), thunk.primal, AugmentedForwardThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, TapeT, args...) -@inline raw_enzyme_call(thunk::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeT}, fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal, TapeT} = -enzyme_call(Val(true), thunk.primal, AugmentedForwardThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, TapeT, args...) 
+@inline (thunk::AdjointThunk{PT,FA,RT,TT,Width,TapeT})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,TapeT} = enzyme_call( + Val(false), + thunk.adjoint, + AdjointThunk{PT,FA,RT,TT,Width,TapeT}, + Val(Width), + Val(false), + TT, + RT, + fn, + TapeT, + args..., +) #=ReturnPrimal=# +@inline raw_enzyme_call( + thunk::AdjointThunk{PT,FA,RT,TT,Width,TapeT}, + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,TapeT} = enzyme_call( + Val(true), + thunk.adjoint, + AdjointThunk, + Val(Width), + Val(false), + TT, + RT, + fn, + TapeT, + args..., +) #=ReturnPrimal=# + +@inline (thunk::AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeT})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal,TapeT} = enzyme_call( + Val(false), + thunk.primal, + AugmentedForwardThunk, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + TapeT, + args..., +) +@inline raw_enzyme_call( + thunk::AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeT}, + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal,TapeT} = enzyme_call( + Val(true), + thunk.primal, + AugmentedForwardThunk, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + TapeT, + args..., +) function jl_set_typeof(v::Ptr{Cvoid}, T) @@ -6561,7 +8661,7 @@ function jl_set_typeof(v::Ptr{Cvoid}, T) return nothing end -@generated function splatnew(::Type{T}, args::TT) where {T,TT <: Tuple} +@generated function splatnew(::Type{T}, args::TT) where {T,TT<:Tuple} return quote Base.@_inline_meta $(Expr(:splatnew, :T, :args)) @@ -6570,7 +8670,12 @@ end # Recursively return x + f(y), where y is active, otherwise x -@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T, F, F2} +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T,F,F2} if forcelhs(T) return x end @@ -6582,31 +8687,41 @@ end end) end -@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:AbstractFloat, F, 
F2} +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T<:AbstractFloat,F,F2} if forcelhs(T) return x end return x + f(y) end -@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:Complex, F, F2} +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T<:Complex,F,F2} if forcelhs(T) return x end return x + f(y) end -@inline mutable_register(::Type{T}) where T <: Integer = true -@inline mutable_register(::Type{T}) where T <: AbstractFloat = false -@inline mutable_register(::Type{Complex{T}}) where T <: AbstractFloat = false -@inline mutable_register(::Type{T}) where T <: Tuple = false -@inline mutable_register(::Type{T}) where T <: NamedTuple = false +@inline mutable_register(::Type{T}) where {T<:Integer} = true +@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false +@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false +@inline mutable_register(::Type{T}) where {T<:Tuple} = false +@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false @inline mutable_register(::Type{Core.Box}) = true -@inline mutable_register(::Type{T}) where T <: Array = true -@inline mutable_register(::Type{T}) where T = ismutabletype(T) +@inline mutable_register(::Type{T}) where {T<:Array} = true +@inline mutable_register(::Type{T}) where {T} = ismutabletype(T) # Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F} +@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F} if !mutable_register(T) for I in eachindex(x) prev = x[I] @@ -6617,16 +8732,16 @@ end # Recursively In-place accumulate(aka +=). E.g. 
generalization of x .+= f(y) -@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F=identity) where {F} +@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F} recursive_accumulate(x.contents, y.contents, seen, f) end -@inline function recursive_accumulate(x::T, y::T, f::F=identity) where {T, F} +@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F} @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - for i in 1:nf + for i = 1:nf if isdefined(x, i) xi = getfield(x, i) ST = Core.Typeof(xi) @@ -6646,9 +8761,13 @@ end elseif T <: AbstractFloat return one(T) elseif T <: Complex - error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff") + error( + "Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff", + ) else - error("Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)") + error( + "Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. 
Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)", + ) end end @@ -6656,56 +8775,72 @@ function add_one_in_place(x) if x isa Base.RefValue x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string(x)) + error( + "Enzyme Mutability Error: Cannot add one in place to immutable value " * + string(x), + ) end return nothing end -@generated function enzyme_call(::Val{RawCall}, fptr::PT, ::Type{CC}, ::Val{width}, ::Val{returnPrimal}, tt::Type{T}, - rt::Type{RT}, fn::FA, ::Type{TapeType}, args::Vararg{Any, N}) where {RawCall, PT, FA, T, RT, TapeType, N, CC, width, returnPrimal} +@generated function enzyme_call( + ::Val{RawCall}, + fptr::PT, + ::Type{CC}, + ::Val{width}, + ::Val{returnPrimal}, + tt::Type{T}, + rt::Type{RT}, + fn::FA, + ::Type{TapeType}, + args::Vararg{Any,N}, +) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal} JuliaContext() do ctx Base.@_inline_meta F = eltype(FA) - is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk + is_forward = + CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk - is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk + is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk needs_tape = CC <: AdjointThunk - argtt = tt.parameters[1] - rettype = rt.parameters[1] + argtt = tt.parameters[1] + rettype = rt.parameters[1] argtypes = DataType[argtt.parameters...] 
- argexprs = Union{Expr, Symbol}[:(args[$i]) for i in 1:N] + argexprs = Union{Expr,Symbol}[:(args[$i]) for i = 1:N] if false && CC <: PrimalErrorThunk - primargs = [quote - convert($(eltype(T)), $(argexprs[i]).val) - end for (i, T) in enumerate(argtypes)] + primargs = [ + quote + convert($(eltype(T)), $(argexprs[i]).val) + end for (i, T) in enumerate(argtypes) + ] return quote fn.val($(primargs...)) - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + error( + "Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up", + ) end end if !RawCall && !(CC <: PrimalErrorThunk) - if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - @show $width - @show $(length(argtypes)), $is_adjoint, $needs_tape, $(length(argexprs)) - @show $argtypes - @show $argexprs throw(MethodError($CC(fptr), (fn, args...))) end end elseif rettype <: Const - if length(argtypes) + needs_tape != length(argexprs) + if length(argtypes) + needs_tape != length(argexprs) return quote throw(MethodError($CC(fptr), (fn, args...))) end end else - if length(argtypes) + needs_tape != length(argexprs) + if length(argtypes) + needs_tape != length(argexprs) return quote throw(MethodError($CC(fptr), (fn, args...))) end @@ -6715,14 +8850,18 @@ end types = DataType[] - if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) + if !(rettype <: Const) && ( + isghostty(eltype(rettype)) || + Core.Compiler.isconstType(eltype(rettype)) || + eltype(rettype) === DataType + ) rrt = eltype(rettype) error("Return type `$rrt` not marked Const, but is ghost or const type.") end - sret_types = [] # Julia types of all returned variables + 
sret_types = [] # Julia types of all returned variables # By ref values we create and need to preserve - ccexprs = Union{Expr, Symbol}[] # The expressions passed to the `llvmcall` + ccexprs = Union{Expr,Symbol}[] # The expressions passed to the `llvmcall` if !isghostty(F) && !Core.Compiler.isconstType(F) isboxed = GPUCompiler.deserves_argbox(F) @@ -6745,8 +8884,8 @@ end push!(types, Any) elseif width == 1 push!(types, F) - else - push!(types, NTuple{width, F}) + else + push!(types, NTuple{width,F}) end push!(ccexprs, argexpr) end @@ -6759,7 +8898,7 @@ end source_typ = eltype(T) expr = argexprs[i] - i+=1 + i += 1 if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) @assert T <: Const if is_adjoint @@ -6798,13 +8937,13 @@ end if width == 1 push!(ActiveRetTypes, source_typ) else - push!(ActiveRetTypes, NTuple{width, source_typ}) + push!(ActiveRetTypes, NTuple{width,source_typ}) end end elseif T <: Duplicated || T <: DuplicatedNoNeed if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end @@ -6820,15 +8959,15 @@ end elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end - isboxedvec = GPUCompiler.deserves_argbox(NTuple{width, source_typ}) + isboxedvec = GPUCompiler.deserves_argbox(NTuple{width,source_typ}) if isboxedvec push!(types, Any) else - push!(types, NTuple{width, source_typ}) + push!(types, NTuple{width,source_typ}) end if is_adjoint push!(ActiveRetTypes, Nothing) @@ -6837,7 +8976,7 @@ end elseif T <: MixedDuplicated if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end @@ -6845,19 +8984,20 @@ end if is_adjoint push!(ActiveRetTypes, Nothing) end - push!(ccexprs, argexpr) + push!(ccexprs, argexpr) elseif T <: BatchMixedDuplicated if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end - isboxedvec = 
GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{source_typ}}) + isboxedvec = + GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{source_typ}}) if isboxedvec push!(types, Any) else - push!(types, NTuple{width, Base.RefValue{source_typ}}) + push!(types, NTuple{width,Base.RefValue{source_typ}}) end if is_adjoint push!(ActiveRetTypes, Nothing) @@ -6870,8 +9010,8 @@ end jlRT = eltype(rettype) if typeof(jlRT) == UnionAll - # Future improvement, add type assertion on load - jlRT = DataType + # Future improvement, add type assertion on load + jlRT = DataType end if is_sret_union(jlRT) @@ -6879,20 +9019,22 @@ end end # API.DFT_OUT_DIFF - if is_adjoint - if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if is_adjoint + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated # TODO handle batch width - if rettype <: Active + if rettype <: Active @assert allocatedinline(jlRT) end j_drT = if width == 1 jlRT else - NTuple{width, jlRT} + NTuple{width,jlRT} end push!(types, j_drT) push!(ccexprs, argexprs[i]) - i+=1 + i += 1 end end @@ -6901,20 +9043,23 @@ end push!(types, TapeType) push!(ccexprs, argexprs[i]) end - i+=1 + i += 1 end if is_adjoint NT = Tuple{ActiveRetTypes...} - if any(any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for b in ActiveRetTypes) + if any( + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for + b in ActiveRetTypes + ) NT = AnonymousStruct(NT) end push!(sret_types, NT) end - + if !(CC <: PrimalErrorThunk) - @assert i == length(argexprs)+1 + @assert i == length(argexprs) + 1 end # Tape @@ -6922,7 +9067,7 @@ end push!(sret_types, TapeType) end - if returnPrimal && !(CC <: ForwardModeThunk) + if returnPrimal && !(CC <: ForwardModeThunk) push!(sret_types, jlRT) end if is_forward @@ -6934,9 +9079,9 @@ end elseif rettype <: MixedDuplicated push!(sret_types, Base.RefValue{jlRT}) elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed - 
push!(sret_types, AnonymousStruct(NTuple{width, jlRT})) + push!(sret_types, AnonymousStruct(NTuple{width,jlRT})) elseif rettype <: BatchMixedDuplicated - push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}})) + push!(sret_types, AnonymousStruct(NTuple{width,Base.RefValue{jlRT}})) elseif CC <: AugmentedForwardThunk push!(sret_types, Nothing) elseif rettype <: Const @@ -6946,17 +9091,21 @@ end end end - if returnPrimal && (CC <: ForwardModeThunk) + if returnPrimal && (CC <: ForwardModeThunk) push!(sret_types, jlRT) end # calls fptr - llvmtys = LLVMType[convert(LLVMType, x; allow_boxed=true) for x in types] + llvmtys = LLVMType[convert(LLVMType, x; allow_boxed = true) for x in types] T_void = convert(LLVMType, Nothing) - combinedReturn = (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...} - if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) + combinedReturn = + (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? 
Union{} : + Tuple{sret_types...} + if any( + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types + ) combinedReturn = AnonymousStruct(combinedReturn) end uses_sret = is_sret(combinedReturn) @@ -6999,7 +9148,10 @@ end if returnRoots tracked = CountTrackedPointers(jltype) - pushfirst!(callparams, alloca!(builder, LLVM.ArrayType(T_prjlvalue, tracked.count))) + pushfirst!( + callparams, + alloca!(builder, LLVM.ArrayType(T_prjlvalue, tracked.count)), + ) pushfirst!(callparams, alloca!(builder, jltype)) end @@ -7007,7 +9159,11 @@ end tape = callparams[end] if TapeType <: EnzymeTapeToLoad llty = from_tape_type(eltype(TapeType)) - tape = bitcast!(builder, tape, LLVM.PointerType(llty, LLVM.addrspace(value_type(tape)))) + tape = bitcast!( + builder, + tape, + LLVM.PointerType(llty, LLVM.addrspace(value_type(tape))), + ) tape = load!(builder, llty, tape) API.SetMustCache!(tape) callparams[end] = tape @@ -7018,10 +9174,13 @@ end end if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) - FT = LLVM.FunctionType(returnRoots ? T_void : T_ret, [value_type(x) for x in callparams]) + FT = LLVM.FunctionType( + returnRoots ? 
T_void : T_ret, + [value_type(x) for x in callparams], + ) lfn = inttoptr!(builder, lfn, LLVM.PointerType(FT)) else - val_inner(::Type{Val{V}}) where V = V + val_inner(::Type{Val{V}}) where {V} = V submod, subname = val_inner(PT) # TODO, consider optimization # However, julia will optimize after this, so no need @@ -7032,7 +9191,7 @@ end end r = call!(builder, FT, lfn, callparams) - + if returnRoots attr = if LLVM.version().major >= 12 TypeAttribute("sret", jltype) @@ -7049,7 +9208,7 @@ end ret!(builder) end reinsert_gcmarker!(llvm_f) - + ir = string(mod) fn = LLVM.name(llvm_f) @@ -7058,16 +9217,23 @@ end if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) return quote Base.@_inline_meta - Base.llvmcall(($ir, $fn), $combinedReturn, - Tuple{$PT, $(types...)}, - fptr, $(ccexprs...)) + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$PT,$(types...)}, + fptr, + $(ccexprs...), + ) end else return quote Base.@_inline_meta - Base.llvmcall(($ir, $fn), $combinedReturn, - Tuple{$(types...)}, - $(ccexprs...)) + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$(types...)}, + $(ccexprs...), + ) end end end @@ -7079,24 +9245,38 @@ end function _link(job, (mod, adjoint_name, primal_name, TapeType)) if job.config.params.ABI <: InlineABI - return CompileResult(Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), TapeType) + return CompileResult( + Val((Symbol(mod), Symbol(adjoint_name))), + Val((Symbol(mod), Symbol(primal_name))), + TapeType, + ) end # Now invoke the JIT jitted_mod = JIT.add!(mod) adjoint_addr = JIT.lookup(jitted_mod, adjoint_name) - adjoint_ptr = pointer(adjoint_addr) + adjoint_ptr = pointer(adjoint_addr) if adjoint_ptr === C_NULL - throw(GPUCompiler.InternalCompilerError(job, "Failed to compile Enzyme thunk, adjoint not found")) + throw( + GPUCompiler.InternalCompilerError( + job, + "Failed to compile Enzyme thunk, adjoint not found", + ), + ) end if primal_name === nothing primal_ptr = C_NULL else 
primal_addr = JIT.lookup(jitted_mod, primal_name) - primal_ptr = pointer(primal_addr) + primal_ptr = pointer(primal_addr) if primal_ptr === C_NULL - throw(GPUCompiler.InternalCompilerError(job, "Failed to compile Enzyme thunk, primal not found")) + throw( + GPUCompiler.InternalCompilerError( + job, + "Failed to compile Enzyme thunk, primal not found", + ), + ) end end @@ -7106,8 +9286,8 @@ end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool=true) - mod, meta = codegen(:llvm, job; optimize=false) +function _thunk(job, postopt::Bool = true) + mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf adjoint_name = name(adjointf) @@ -7117,14 +9297,14 @@ function _thunk(job, postopt::Bool=true) else primal_name = nothing end - + LLVM.ModulePassManager() do pm add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) LLVM.run!(pm, mod) end # Run post optimization pipeline - if postopt + if postopt if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) if DumpPostOpt[] @@ -7137,7 +9317,7 @@ function _thunk(job, postopt::Bool=true) return (mod, adjoint_name, primal_name, meta.TapeType) end -const cache = Dict{UInt, CompileResult}() +const cache = Dict{UInt,CompileResult}() const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult @@ -7167,20 +9347,65 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, 
ABI, ErrIfFuncWritten, RuntimeActivity} +@inline function thunkbase( + ctx, + mi::Core.MethodInstance, + ::Val{World}, + ::Type{FA}, + ::Type{A}, + tt::Type{TT}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ABI}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + FA<:Annotation, + A<:Annotation, + TT, + Mode, + ModifiedBetween, + width, + ReturnPrimal, + ShadowInit, + World, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} target = Compiler.EnzymeTarget() - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) - tmp_job = if World isa Nothing - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) - else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - end + params = Compiler.EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + remove_innerty(A), + true, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + UnknownTapeType, + ABI, + ErrIfFuncWritten, + RuntimeActivity, + ) #=abiwrap=# + tmp_job = if World isa Nothing + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) + else + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) + end interp = GPUCompiler.get_interpreter(tmp_job) # TODO check compile return here, early # rrt = Core.Compiler.return_type(f, primal.tt) # nothing - rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = something( + Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), + Any, + ) rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype run_enzyme = true @@ -7188,12 +9413,12 @@ end A2 = if rrt == Union{} run_enzyme = false Const - else + else A end - + if run_enzyme && !(A2 <: Const) && 
guaranteed_const_nongen(rrt, World) - estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" + estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" return error(estr) end @@ -7207,13 +9432,27 @@ end # @assert eltype(A) == rrt A2 end - - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) - job = if World isa Nothing - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) + + params = Compiler.EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + rt2, + run_enzyme, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + UnknownTapeType, + ABI, + ErrIfFuncWritten, + RuntimeActivity, + ) #=abiwrap=# + job = if World isa Nothing + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - end + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) + end # We need to use primal as the key, to lookup the right method # but need to mixin the hash of the adjoint to avoid cache collisions # This is counter-intuitive since we would expect the cache to be split @@ -7223,7 +9462,7 @@ end compile_result = cached_compilation(job) if !run_enzyme - ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal} + ErrT = PrimalErrorThunk{typeof(compile_result.adjoint),FA,rt2,TT,width,ReturnPrimal} if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient return (ErrT(compile_result.adjoint), ErrT(compile_result.adjoint)) else @@ -7231,71 +9470,227 @@ end end elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType - AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, 
Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} - AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} + AugT = AugmentedForwardThunk{ + typeof(compile_result.primal), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + ReturnPrimal, + TapeType, + } + AdjT = AdjointThunk{ + typeof(compile_result.adjoint), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + TapeType, + } return (AugT(compile_result.primal), AdjT(compile_result.adjoint)) elseif Mode == API.DEM_ReverseModeCombined - CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} + CAdjT = CombinedAdjointThunk{ + typeof(compile_result.adjoint), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + ReturnPrimal, + } return CAdjT(compile_result.adjoint) elseif Mode == API.DEM_ForwardMode - FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} + FMT = ForwardModeThunk{ + typeof(compile_result.adjoint), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + ReturnPrimal, + } return FMT(compile_result.adjoint) else @assert false end end -@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten, RuntimeActivity} - ts_ctx = JuliaContext() - ctx = context(ts_ctx) - activate(ctx) - try - return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - finally - deactivate(ctx) - dispose(ts_ctx) - end -end - -@inline @generated 
function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten, RuntimeActivity} - mi = fspec(eltype(FA), TT, World) - ts_ctx = JuliaContext() - ctx = context(ts_ctx) - activate(ctx) - res = try - thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - finally - deactivate(ctx) - dispose(ts_ctx) - end - return quote - Base.@_inline_meta - return $(res) - end +@inline function thunk( + mi::Core.MethodInstance, + ::Type{FA}, + ::Type{A}, + tt::Type{TT}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ABI}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + FA<:Annotation, + A<:Annotation, + TT, + Mode, + ModifiedBetween, + width, + ReturnPrimal, + ShadowInit, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} + ts_ctx = JuliaContext() + ctx = context(ts_ctx) + activate(ctx) + try + return thunkbase( + ctx, + mi, + Val(nothing), + FA, + A, + TT, + Val(Mode), + Val(width), + Val(ModifiedBetween), + Val(ReturnPrimal), + Val(ShadowInit), + ABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=World=# + finally + deactivate(ctx) + dispose(ts_ctx) + end +end + +@inline @generated function thunk( + ::Val{World}, + ::Type{FA}, + ::Type{A}, + tt::Type{TT}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ABI}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + FA<:Annotation, + A<:Annotation, + TT, + Mode, + ModifiedBetween, + width, + ReturnPrimal, + ShadowInit, + World, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} + mi = 
fspec(eltype(FA), TT, World) + ts_ctx = JuliaContext() + ctx = context(ts_ctx) + activate(ctx) + res = try + thunkbase( + ctx, + mi, + Val(World), + FA, + A, + TT, + Val(Mode), + Val(width), + Val(ModifiedBetween), + Val(ReturnPrimal), + Val(ShadowInit), + ABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) + finally + deactivate(ctx) + dispose(ts_ctx) + end + return quote + Base.@_inline_meta + return $(res) + end end import GPUCompiler: deferred_codegen_jobs -@generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, - ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal},::Val{ShadowInit},::Type{ExpectedTapeType}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, ErrIfFuncWritten, RuntimeActivity} +@generated function deferred_codegen( + ::Val{World}, + ::Type{FA}, + ::Val{TT}, + ::Val{A}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ExpectedTapeType}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + World, + FA<:Annotation, + TT, + A, + Mode, + width, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + ErrIfFuncWritten, + RuntimeActivity, +} JuliaContext() do ctx Base.@_inline_meta mi = fspec(eltype(FA), TT, World) target = EnzymeTarget() - rt2 = if A isa UnionAll - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, RuntimeActivity) - tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - + rt2 = if A isa UnionAll + params = EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + remove_innerty(A), + true, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + FFIABI, + ErrIfFuncWritten, + 
RuntimeActivity, + ) #=abiwrap=# + tmp_job = Compiler.CompilerJob( + mi, + CompilerConfig(target, params; kernel = false), + World, + ) + interp = GPUCompiler.get_interpreter(tmp_job) - rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = something( + Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), + Any, + ) # Don't error here but default to nothing return since in cuda context we don't use the device overrides if rrt == Union{} rrt = Nothing end - + if !(A <: Const) && guaranteed_const_nongen(rrt, World) estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" return quote @@ -7307,9 +9702,24 @@ import GPUCompiler: deferred_codegen_jobs @assert A isa DataType A end - - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, RuntimeActivity) - job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + + params = EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + rt2, + true, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + FFIABI, + ErrIfFuncWritten, + RuntimeActivity, + ) #=abiwrap=# + job = + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) addr = get_trampoline(job) id = Base.reinterpret(Int, pointer(addr)) @@ -7317,7 +9727,13 @@ import GPUCompiler: deferred_codegen_jobs quote Base.@_inline_meta - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, id))) + ccall( + "extern deferred_codegen", + llvmcall, + Ptr{Cvoid}, + (Ptr{Cvoid},), + $(reinterpret(Ptr{Cvoid}, id)), + ) end end end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index c167581c3a..4d48297ae5 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,6 +1,12 @@ module 
Interpreter import Enzyme: API -using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance +using Core.Compiler: + AbstractInterpreter, + InferenceResult, + InferenceParams, + InferenceState, + OptimizationParams, + MethodInstance using GPUCompiler: @safe_debug if VERSION < v"1.11.0-DEV.1552" using GPUCompiler: CodeCache, WorldView, @safe_debug @@ -18,11 +24,11 @@ else end struct EnzymeInterpreter <: AbstractInterpreter -@static if HAS_INTEGRATED_CACHE - token::Any -else - code_cache::CodeCache -end + @static if HAS_INTEGRATED_CACHE + token::Any + else + code_cache::CodeCache + end method_table::Union{Nothing,Core.MethodTable} # Cache of inference results for this particular interpreter @@ -37,11 +43,16 @@ end mode::API.CDerivativeMode end -function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) +function EnzymeInterpreter( + cache_or_token, + mt::Union{Nothing,Core.MethodTable}, + world::UInt, + mode::API.CDerivativeMode, +) @assert world <= Base.get_world_counter() parms = @static if VERSION < v"1.12" - InferenceParams(unoptimize_throw_blocks=false) + InferenceParams(unoptimize_throw_blocks = false) else InferenceParams() end @@ -57,9 +68,9 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world, # parameters for inference and optimization - parms, + parms, OptimizationParams(), - mode + mode, ) end @@ -70,7 +81,8 @@ Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cach @static if HAS_INTEGRATED_CACHE Core.Compiler.cache_owner(interp::EnzymeInterpreter) = interp.token else - Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.code_cache, interp.world) + Core.Compiler.code_cache(interp::EnzymeInterpreter) = + WorldView(interp.code_cache, interp.world) end # No need to do any locking since we're not putting our results into the runtime cache @@ -87,14 +99,14 
@@ Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false if isdefined(Base.Experimental, Symbol("@overlay")) -Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) + Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = + Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) else -# On 1.6- CUDA.jl will poison the method table at the end of the world -# using GPUCompiler: WorldOverlayMethodTable -# Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = -# WorldOverlayMethodTable(interp.world) + # On 1.6- CUDA.jl will poison the method table at the end of the world + # using GPUCompiler: WorldOverlayMethodTable + # Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = + # WorldOverlayMethodTable(interp.world) end function is_alwaysinline_func(@nospecialize(TT)) @@ -114,8 +126,11 @@ function is_primitive_func(@nospecialize(TT)) end # FIXME(@wsmoses): For which types should we not inline? 
- if ft === typeof(Base.wait) || ft === typeof(Base._wait) || ft === typeof(Base.enq_work) || - ft === typeof(Base.Threads.threadid) || ft == typeof(Base.Threads.nthreads) || + if ft === typeof(Base.wait) || + ft === typeof(Base._wait) || + ft === typeof(Base.enq_work) || + ft === typeof(Base.Threads.threadid) || + ft == typeof(Base.Threads.nthreads) || ft === typeof(Base.Threads.threading_run) return true end @@ -123,7 +138,7 @@ function is_primitive_func(@nospecialize(TT)) end function isKWCallSignature(@nospecialize(TT)) - return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} + return TT <: Tuple{typeof(Core.kwcall),Any,Any,Vararg} end function simplify_kw(@nospecialize specTypes) @@ -137,27 +152,46 @@ end import Core.Compiler: CallInfo struct NoInlineCallInfo <: CallInfo info::CallInfo # wrapped call - tt # ::Type + tt::Any # ::Type kind::Symbol - NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = new(info, tt, kind) + NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = + new(info, tt, kind) end Core.Compiler.nsplit_impl(info::NoInlineCallInfo) = Core.Compiler.nsplit(info.info) -Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) -Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) = + Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) = + Core.Compiler.getresult(info.info, idx) struct AlwaysInlineCallInfo <: CallInfo info::CallInfo # wrapped call - tt # ::Type + tt::Any # ::Type AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt) end Core.Compiler.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info) -Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) 
-Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = + Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = + Core.Compiler.getresult(info.info, idx) using Core.Compiler: ArgInfo, StmtInfo, AbsIntState -function Core.Compiler.abstract_call_gf_by_type(interp::EnzymeInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int) - ret = @invoke Core.Compiler.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, - arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int) +function Core.Compiler.abstract_call_gf_by_type( + interp::EnzymeInterpreter, + @nospecialize(f), + arginfo::ArgInfo, + si::StmtInfo, + @nospecialize(atype), + sv::AbsIntState, + max_methods::Int, +) + ret = @invoke Core.Compiler.abstract_call_gf_by_type( + interp::AbstractInterpreter, + f::Any, + arginfo::ArgInfo, + si::StmtInfo, + atype::Any, + sv::AbsIntState, + max_methods::Int, + ) callinfo = ret.info method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) @@ -175,21 +209,43 @@ function Core.Compiler.abstract_call_gf_by_type(interp::EnzymeInterpreter, @nosp callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end @static if VERSION ≥ v"1.11-" - return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) + return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) else - return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo) + return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo) end end let # overload `inlining_policy` @static if VERSION ≥ v"1.11.0-DEV.879" - sigs_ex = :(interp::EnzymeInterpreter, @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt32) - args_ex = :(interp::AbstractInterpreter, src::Any, info::Core.Compiler.CallInfo, 
stmt_flag::UInt32) + sigs_ex = :( + interp::EnzymeInterpreter, + @nospecialize(src), + @nospecialize(info::Core.Compiler.CallInfo), + stmt_flag::UInt32, + ) + args_ex = :( + interp::AbstractInterpreter, + src::Any, + info::Core.Compiler.CallInfo, + stmt_flag::UInt32, + ) else - sigs_ex = :(interp::EnzymeInterpreter, - @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) - args_ex = :(interp::AbstractInterpreter, - src::Any, info::Core.Compiler.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) + sigs_ex = :( + interp::EnzymeInterpreter, + @nospecialize(src), + @nospecialize(info::Core.Compiler.CallInfo), + stmt_flag::UInt8, + mi::MethodInstance, + argtypes::Vector{Any}, + ) + args_ex = :( + interp::AbstractInterpreter, + src::Any, + info::Core.Compiler.CallInfo, + stmt_flag::UInt8, + mi::MethodInstance, + argtypes::Vector{Any}, + ) end @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) if info isa NoInlineCallInfo @@ -212,20 +268,36 @@ let # overload `inlining_policy` end end -import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, - CallMeta, Effects, NoCallInfo, widenconst, mapany, MethodResultPure +import Core.Compiler: + abstract_call, + abstract_call_known, + ArgInfo, + StmtInfo, + AbsIntState, + get_max_methods, + CallMeta, + Effects, + NoCallInfo, + widenconst, + mapany, + MethodResultPure struct AutodiffCallInfo <: CallInfo # ... 
info::CallInfo end -function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, - max_methods::Int = get_max_methods(interp, f, sv)) +function abstract_call_known( + interp::EnzymeInterpreter, + @nospecialize(f), + arginfo::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int = get_max_methods(interp, f, sv), +) (; fargs, argtypes) = arginfo - + if f === Enzyme.within_autodiff if length(argtypes) != 1 @static if VERSION < v"1.11.0-" @@ -235,26 +307,48 @@ function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), end end @static if VERSION < v"1.11.0-" - return CallMeta(Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + return CallMeta( + Core.Const(true), + Core.Compiler.EFFECTS_TOTAL, + MethodResultPure(), + ) else - return CallMeta(Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + return CallMeta( + Core.Const(true), + Union{}, + Core.Compiler.EFFECTS_TOTAL, + MethodResultPure(), + ) end end if f === Enzyme.autodiff && length(argtypes) >= 4 - if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : [:(Enzyme.autodiff_deferred), fargs[2:end]...], - [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...] - ) - return abstract_call_known( - interp, Enzyme.autodiff_deferred, arginfo2, - si, sv, max_methods) - end + if widenconst(argtypes[2]) <: Enzyme.Mode && + widenconst(argtypes[3]) <: Enzyme.Annotation && + widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} + arginfo2 = ArgInfo( + fargs isa Nothing ? 
nothing : + [:(Enzyme.autodiff_deferred), fargs[2:end]...], + [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], + ) + return abstract_call_known( + interp, + Enzyme.autodiff_deferred, + arginfo2, + si, + sv, + max_methods, + ) + end end return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, f, arginfo::ArgInfo, - si::StmtInfo, sv::AbsIntState, max_methods::Int) + interp::AbstractInterpreter, + f, + arginfo::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int, + ) end end diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 2e3e8194c9..d11daaa0b3 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -15,24 +15,52 @@ struct PipelineConfig cleanup::Cint end -const RunAttributor = Ref(true) - -function pipeline_options(; lower_intrinsics=true, dump_native=false, external_use=false, llvm_only=false, always_inline=true, enable_early_simplifications=true, - enable_early_optimizations=true, - enable_scalar_optimizations=true, - enable_loop_optimizations=true, - enable_vector_pipeline=true, - remove_ni=true, - cleanup=true, Size=0, Speedup=3) - return PipelineConfig(Speedup, Size, lower_intrinsics, dump_native, external_use, llvm_only, always_inline, enable_early_simplifications, enable_early_optimizations, enable_scalar_optimizations, enable_loop_optimizations, enable_vector_pipeline, remove_ni, cleanup) +const RunAttributor = Ref(true) + +function pipeline_options(; + lower_intrinsics = true, + dump_native = false, + external_use = false, + llvm_only = false, + always_inline = true, + enable_early_simplifications = true, + enable_early_optimizations = true, + enable_scalar_optimizations = true, + enable_loop_optimizations = true, + enable_vector_pipeline = true, + remove_ni = true, + cleanup = true, + Size = 0, + Speedup = 3, +) + return PipelineConfig( + Speedup, + Size, + lower_intrinsics, + dump_native, + external_use, + llvm_only, + always_inline, + enable_early_simplifications, + 
enable_early_optimizations, + enable_scalar_optimizations, + enable_loop_optimizations, + enable_vector_pipeline, + remove_ni, + cleanup, + ) end -function run_jl_pipeline(pm, tm; kwargs...) - config = Ref(pipeline_options(;kwargs...)) +function run_jl_pipeline(pm, tm; kwargs...) + config = Ref(pipeline_options(; kwargs...)) function jl_pipeline(m) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm - @ccall jl_build_newpm_pipeline(mpm.ref::Ptr{Cvoid}, pb.ref::Ptr{Cvoid}, config::Ptr{PipelineConfig})::Cvoid + @ccall jl_build_newpm_pipeline( + mpm.ref::Ptr{Cvoid}, + pb.ref::Ptr{Cvoid}, + config::Ptr{PipelineConfig}, + )::Cvoid end LLVM.run!(mpm, m, tm) end @@ -53,10 +81,10 @@ end else function gc_invariant_verifier_tm!(pm, tm, cond) function gc_invariant_verifier(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm - add!(fpm, GCInvariantVerifierPass(;strong=cond)) + add!(fpm, GCInvariantVerifierPass(; strong = cond)) end end run!(pb, mod) @@ -74,7 +102,7 @@ end else function propagate_julia_addrsp_tm!(pm, tm) function prop_julia_addr(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, PropagateJuliaAddrspacesPass()) @@ -95,7 +123,7 @@ end else function alloc_opt_tm!(pm, tm) function alloc_opt(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, AllocOptPass()) @@ -116,7 +144,7 @@ end else function remove_ni_tm!(pm, tm) function remove_ni(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, RemoveNIPass()) end @@ -135,7 +163,7 @@ end else 
function julia_licm_tm!(pm, tm) function julia_licm(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, NewPMLoopPassManager()) do lpm @@ -159,7 +187,7 @@ end else function lower_simdloop_tm!(pm, tm) function lower_simdloop(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, NewPMLoopPassManager()) do lpm @@ -181,13 +209,28 @@ function loop_optimizations_tm!(pm, tm) @static if true || VERSION < v"1.11-" lower_simdloop_tm!(pm, tm) licm!(pm) - if LLVM.version() >= v"15" + if LLVM.version() >= v"15" simple_loop_unswitch_legacy!(pm) else loop_unswitch!(pm) end else - run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) + run_jl_pipeline( + pm, + tm; + lower_intrinsics = false, + dump_native = false, + external_use = false, + llvm_only = false, + always_inline = false, + enable_early_simplifications = false, + enable_early_optimizations = false, + enable_scalar_optimizations = false, + enable_loop_optimizations = true, + enable_vector_pipeline = false, + remove_ni = false, + cleanup = false, + ) end end @@ -205,7 +248,7 @@ function more_loop_optimizations_tm!(pm, tm) # Subsequent passes not stripping metadata from terminator instruction_combining!(pm) # TODO: createInstSimplifyLegacy jl_inst_simplify!(pm) - + ind_var_simplify!(pm) loop_deletion!(pm) loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll @@ -224,7 +267,22 @@ function more_loop_optimizations_tm!(pm, tm) # IndVarSimplifyPass # LoopDeletionPass # LoopFullUnrollPass - 
run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) + run_jl_pipeline( + pm, + tm; + lower_intrinsics = false, + dump_native = false, + external_use = false, + llvm_only = false, + always_inline = false, + enable_early_simplifications = false, + enable_early_optimizations = false, + enable_scalar_optimizations = false, + enable_loop_optimizations = true, + enable_vector_pipeline = false, + remove_ni = false, + cleanup = false, + ) end end @@ -235,7 +293,7 @@ end else function demote_float16_tm!(pm, tm) function demote_float16(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, DemoteFloat16Pass()) @@ -256,7 +314,7 @@ end else function lower_exc_handlers_tm!(pm, tm) function lower_exc_handlers(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, LowerExcHandlersPass()) @@ -277,7 +335,7 @@ end else function lower_ptls_tm!(pm, tm, dump_native) function lower_ptls(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, LowerPTLSPass()) end @@ -296,7 +354,7 @@ end else function combine_mul_add_tm!(pm, tm) function combine_mul_add(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, CombineMulAddPass()) @@ -317,7 +375,7 @@ end else function late_lower_gc_frame_tm!(pm, tm) function late_lower_gc_frame(mod) - @dispose pb=NewPMPassBuilder() begin 
+ @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, LateLowerGCPass()) @@ -338,7 +396,7 @@ end else function final_lower_gc_tm!(pm, tm) function final_lower_gc(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, FinalLowerGCPass()) @@ -355,17 +413,17 @@ end @static if VERSION < v"1.11-" function cpu_features_tm!(pm, tm) @static if isdefined(LLVM.Interop, :cpu_features!) - LLVM.Interop.cpu_features!(pm) + LLVM.Interop.cpu_features!(pm) else - @static if isdefined(GPUCompiler, :cpu_features!) + @static if isdefined(GPUCompiler, :cpu_features!) GPUCompiler.cpu_features!(pm) - end + end end end else function cpu_features_tm!(pm, tm) function cpu_features(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, CPUFeaturesPass()) end @@ -379,7 +437,7 @@ end function addNA(inst, node::LLVM.Metadata, MD) md = metadata(inst) - next = nothing + next = nothing if haskey(md, MD) next = LLVM.MDNode(Metadata[node, operands(md[MD])...]) else @@ -405,7 +463,7 @@ function addr13NoAlias(mod::LLVM.Module) end end elseif isa(inst, LLVM.LoadInst) - ty =value_type(inst) + ty = value_type(inst) if isa(ty, LLVM.PointerType) if addrspace(ty) == 13 addNA(inst, aliasscope, LLVM.MD_alias_scope) @@ -432,7 +490,7 @@ end # turn this into load/store, as this is more # amenable to caching analysis infrastructure function memcpy_alloca_to_loadstore(mod::LLVM.Module) - dl = datalayout(mod) + dl = datalayout(mod) for f in functions(mod) if length(blocks(f)) != 0 bb = first(blocks(f)) @@ -441,21 +499,24 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) if !isa(alloca, LLVM.AllocaInst) continue end - todo = Tuple{LLVM.Instruction, LLVM.Value}[(alloca, alloca)] + todo = 
Tuple{LLVM.Instruction,LLVM.Value}[(alloca, alloca)] copy = nothing legal = true elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca)) lifetimestarts = LLVM.Instruction[] while length(todo) > 0 cur, prev = pop!(todo) - if isa(cur, LLVM.AllocaInst) || isa(cur, LLVM.AddrSpaceCastInst) || isa(cur, LLVM.BitCastInst) + if isa(cur, LLVM.AllocaInst) || + isa(cur, LLVM.AddrSpaceCastInst) || + isa(cur, LLVM.BitCastInst) for u in LLVM.uses(cur) u = LLVM.user(u) push!(todo, (u, cur)) end continue end - if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + if isa(cur, LLVM.CallInst) && + isa(LLVM.called_operand(cur), LLVM.Function) intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur)) if intr == LLVM.Intrinsic("llvm.lifetime.start").id push!(lifetimestarts, cur) @@ -466,7 +527,9 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) end if intr == LLVM.Intrinsic("llvm.memcpy").id sz = operands(cur)[3] - if operands(cur)[1] == prev && isa(sz, LLVM.ConstantInt) && convert(Int, sz) == sizeof(dl, elty) + if operands(cur)[1] == prev && + isa(sz, LLVM.ConstantInt) && + convert(Int, sz) == sizeof(dl, elty) if copy === nothing || copy == cur copy = cur continue @@ -479,13 +542,16 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) if isa(cur, LLVM.LoadInst) continue end - if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + if isa(cur, LLVM.CallInst) && + isa(LLVM.called_operand(cur), LLVM.Function) legalc = true for (i, ci) in enumerate(operands(cur)[1:end-1]) if ci == prev nocapture = false readonly = false - for a in collect(parameter_attributes(LLVM.called_operand(cur), i)) + for a in collect( + parameter_attributes(LLVM.called_operand(cur), i), + ) if kind(a) == kind(EnumAttribute("readonly")) readonly = true end @@ -510,21 +576,35 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) legal = false break end - + if legal && copy !== nothing B = LLVM.IRBuilder() position!(B, copy) dst = operands(copy)[1] src = 
operands(copy)[2] - dst0 = bitcast!(B, dst, LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst)))) + dst0 = bitcast!( + B, + dst, + LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst))), + ) - dst = bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) - src = bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) + dst = + bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) + src = + bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) src = load!(B, elty, src) - FT = LLVM.FunctionType(LLVM.VoidType(), [LLVM.IntType(64), value_type(dst0)]) + FT = LLVM.FunctionType( + LLVM.VoidType(), + [LLVM.IntType(64), value_type(dst0)], + ) lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT) - call!(B, FT, lifetimestart, LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0]) + call!( + B, + FT, + lifetimestart, + LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0], + ) store!(B, src, dst) push!(todel, copy) end @@ -601,289 +681,376 @@ function nodecayed_phis!(mod::LLVM.Module) end end - offty = LLVM.IntType(8*sizeof(Int)) + offty = LLVM.IntType(8 * sizeof(Int)) i8 = LLVM.IntType(8) for addr in (11, 13) - nextvs = Dict{LLVM.PHIInst, LLVM.PHIInst}() - mtodo = Vector{LLVM.PHIInst}[] - goffsets = Dict{LLVM.PHIInst, LLVM.PHIInst}() - nonphis = LLVM.Instruction[] - anyV = false - for bb in blocks(f) - todo = LLVM.PHIInst[] - nonphi = nothing - for inst in instructions(bb) - if !isa(inst, LLVM.PHIInst) - nonphi = inst - break - end - ty = value_type(inst) - if !isa(ty, LLVM.PointerType) - continue - end - if addrspace(ty) != addr - continue - end - if addr == 11 - all_args = true - addrtodo = Value[inst] - seen = Set{LLVM.Value}() + nextvs = Dict{LLVM.PHIInst,LLVM.PHIInst}() + mtodo = Vector{LLVM.PHIInst}[] + goffsets = Dict{LLVM.PHIInst,LLVM.PHIInst}() + nonphis = LLVM.Instruction[] + anyV = false + for bb in blocks(f) + todo = LLVM.PHIInst[] + nonphi = nothing + for inst in 
instructions(bb) + if !isa(inst, LLVM.PHIInst) + nonphi = inst + break + end + ty = value_type(inst) + if !isa(ty, LLVM.PointerType) + continue + end + if addrspace(ty) != addr + continue + end + if addr == 11 + all_args = true + addrtodo = Value[inst] + seen = Set{LLVM.Value}() - while length(addrtodo) != 0 - v = pop!(addrtodo) - base = get_base_object(v) - if in(base, seen) - continue - end - push!(seen, base) - if isa(base, LLVM.Argument) && addrspace(value_type(base)) == 11 - continue - end - if isa(base, LLVM.PHIInst) - for (v, _) in LLVM.incoming(base) - push!(addrtodo, v) + while length(addrtodo) != 0 + v = pop!(addrtodo) + base = get_base_object(v) + if in(base, seen) + continue end + push!(seen, base) + if isa(base, LLVM.Argument) && addrspace(value_type(base)) == 11 + continue + end + if isa(base, LLVM.PHIInst) + for (v, _) in LLVM.incoming(base) + push!(addrtodo, v) + end + continue + end + all_args = false + break + end + if all_args continue end - all_args = false - break end - if all_args - continue + + push!(todo, inst) + nb = IRBuilder() + position!(nb, inst) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) end - end - - push!(todo, inst) - nb = IRBuilder() - position!(nb, inst) - el_ty = if addr == 11 - eltype(ty) - else - LLVM.StructType(LLVM.LLVMType[]) - end - nphi = phi!(nb, LLVM.PointerType(el_ty, 10), "nodecayed." * LLVM.name(inst)) - nextvs[inst] = nphi - anyV = true + nphi = phi!( + nb, + LLVM.PointerType(el_ty, 10), + "nodecayed." * LLVM.name(inst), + ) + nextvs[inst] = nphi + anyV = true - goffsets[inst] = phi!(nb, offty, "nodecayedoff." * LLVM.name(inst)) + goffsets[inst] = phi!(nb, offty, "nodecayedoff." 
* LLVM.name(inst)) + end + push!(mtodo, todo) + push!(nonphis, nonphi) end - push!(mtodo, todo) - push!(nonphis, nonphi) - end - for (bb, todo, nonphi) in zip(blocks(f), mtodo, nonphis) + for (bb, todo, nonphi) in zip(blocks(f), mtodo, nonphis) - for inst in todo - ty = value_type(inst) - el_ty = if addr == 11 - eltype(ty) - else - LLVM.StructType(LLVM.LLVMType[]) - end - nvs = Tuple{LLVM.Value, LLVM.BasicBlock}[] - offsets = Tuple{LLVM.Value, LLVM.BasicBlock}[] - for (v, pb) in LLVM.incoming(inst) - done = false - for ((nv, pb0), (offset, pb1)) in zip(nvs, offsets) - if pb0 == pb - push!(nvs, (nv, pb)) - push!(offsets, (offset, pb)) - done = true - break - end - end - if done - continue - end - b = IRBuilder() - position!(b, terminator(pb)) + for inst in todo + ty = value_type(inst) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) + end + nvs = Tuple{LLVM.Value,LLVM.BasicBlock}[] + offsets = Tuple{LLVM.Value,LLVM.BasicBlock}[] + for (v, pb) in LLVM.incoming(inst) + done = false + for ((nv, pb0), (offset, pb1)) in zip(nvs, offsets) + if pb0 == pb + push!(nvs, (nv, pb)) + push!(offsets, (offset, pb)) + done = true + break + end + end + if done + continue + end + b = IRBuilder() + position!(b, terminator(pb)) - v0 = v - @inline function getparent(v, offset, hasload) - if addr == 11 && addrspace(value_type(v)) == 10 - return v, offset, hasload - end - if addr == 13 && hasload && addrspace(value_type(v)) == 10 - return v, offset, hasload - end - if addr == 13 && isa(v, LLVM.LoadInst) && !hasload - return getparent(operands(v)[1], offset, true) - end + v0 = v + @inline function getparent(v, offset, hasload) + if addr == 11 && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && hasload && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && isa(v, LLVM.LoadInst) && !hasload + return getparent(operands(v)[1], offset, true) + end - if addr == 13 && isa(v, LLVM.ConstantExpr) - if 
opcode(v) == LLVM.API.LLVMAddrSpaceCast - v2 = operands(v)[1] - if addrspace(value_type(v2)) == 0 - if addr == 13 && isa(v, LLVM.ConstantExpr) - v2 = const_addrspacecast(operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) - return v2, offset, hasload + if addr == 13 && isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 0 + if addr == 13 && isa(v, LLVM.ConstantExpr) + v2 = const_addrspacecast( + operands(v)[1], + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end end end - end - end - if addr == 11 && isa(v, LLVM.ConstantExpr) - if opcode(v) == LLVM.API.LLVMAddrSpaceCast - v2 = operands(v)[1] - if addrspace(value_type(v2)) == 10 - return v2, offset, hasload + if addr == 11 && isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 10 + return v2, offset, hasload + end + if addrspace(value_type(v2)) == 0 + if addr == 11 + v2 = const_addrspacecast( + v2, + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end + if LLVM.isnull(v2) + v2 = const_addrspacecast( + v2, + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end end - if addrspace(value_type(v2)) == 0 - if addr == 11 - v2 = const_addrspacecast(v2, LLVM.PointerType(eltype(value_type(v)), 10)) + + if isa(v, LLVM.AddrSpaceCastInst) + if addrspace(value_type(operands(v)[1])) == 0 + v2 = addrspacecast!( + b, + operands(v)[1], + LLVM.PointerType(eltype(value_type(v)), 10), + ) return v2, offset, hasload end + nv, noffset, nhasload = + getparent(operands(v)[1], offset, hasload) + if eltype(value_type(nv)) != eltype(value_type(v)) + nv = bitcast!( + b, + nv, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(nv)), + ), + ) + end + return nv, noffset, nhasload end - if LLVM.isnull(v2) - v2 = const_addrspacecast(v2, 
LLVM.PointerType(eltype(value_type(v)), 10)) - return v2, offset, hasload - end - end - end - if isa(v, LLVM.AddrSpaceCastInst) - if addrspace(value_type(operands(v)[1])) == 0 - v2 = addrspacecast!(b, operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) - return v2, offset, hasload - end - nv, noffset, nhasload = getparent(operands(v)[1], offset, hasload) - if eltype(value_type(nv)) != eltype(value_type(v)) - nv = bitcast!(b, nv, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(nv)))) - end - return nv, noffset, nhasload - end + if isa(v, LLVM.BitCastInst) + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.BitCastInst) - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + if isa(v, LLVM.GetElementPtrInst) && all( + x -> (isa(x, LLVM.ConstantInt) && convert(Int, x) == 0), + operands(v)[2:end], + ) + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.GetElementPtrInst) && all(x->(isa(x, LLVM.ConstantInt) && convert(Int, x) == 0), operands(v)[2:end]) - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + if isa(v, LLVM.GetElementPtrInst) && !hasload + v2, offset, skipload = + 
getparent(operands(v)[1], offset, hasload) + offset = nuwadd!( + b, + offset, + API.EnzymeComputeByteOffsetOfGEP(b, v, offty), + ) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.GetElementPtrInst) && !hasload - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - offset = nuwadd!(b, offset, API.EnzymeComputeByteOffsetOfGEP(b, v, offty)) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + if isa(v, LLVM.ConstantExpr) && + opcode(v) == LLVM.API.LLVMGetElementPtr && + !hasload + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + offset = nuwadd!( + b, + offset, + API.EnzymeComputeByteOffsetOfGEP(b, v, offty), + ) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.ConstantExpr) && opcode(v) == LLVM.API.LLVMGetElementPtr && !hasload - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - offset = nuwadd!(b, offset, API.EnzymeComputeByteOffsetOfGEP(b, v, offty)) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + undeforpoison = isa(v, LLVM.UndefValue) + @static if LLVM.version() >= v"12" + undeforpoison |= isa(v, LLVM.PoisonValue) + end + if undeforpoison + return LLVM.UndefValue( + LLVM.PointerType(eltype(value_type(v)), 10), + ), + offset, + addr == 13 + end - undeforpoison = isa(v, LLVM.UndefValue) - @static if LLVM.version() >= v"12" - undeforpoison |= isa(v, LLVM.PoisonValue) - end - if undeforpoison - return 
LLVM.UndefValue(LLVM.PointerType(eltype(value_type(v)),10)), offset, addr == 13 - end + if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v) + offset = nuwadd!(b, offset, goffsets[v]) + nv = nextvs[v] + return nv, offset, addr == 13 + end - if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v) - offset = nuwadd!(b, offset, goffsets[v]) - nv = nextvs[v] - return nv, offset, addr == 13 - end + if isa(v, LLVM.SelectInst) + lhs_v, lhs_offset, lhs_skipload = + getparent(operands(v)[2], offset, hasload) + rhs_v, rhs_offset, rhs_skipload = + getparent(operands(v)[3], offset, hasload) + if value_type(lhs_v) != value_type(rhs_v) || + value_type(lhs_offset) != value_type(rhs_offset) || + lhs_skipload != rhs_skipload + msg = sprint() do io + println( + io, + "Could not analyze [select] garbage collection behavior of", + ) + println(io, " v0: ", string(v0)) + println(io, " v: ", string(v)) + println(io, " offset: ", string(offset)) + println(io, " hasload: ", string(hasload)) + println(io, " lhs_v", lhs_v) + println(io, " rhs_v", rhs_v) + println(io, " lhs_offset", lhs_offset) + println(io, " rhs_offset", rhs_offset) + println(io, " lhs_skipload", lhs_skipload) + println(io, " rhs_skipload", rhs_skipload) + end + bt = GPUCompiler.backtrace(inst) + throw(EnzymeInternalError(msg, string(f), bt)) + end + return select!(b, operands(v)[1], lhs_v, rhs_v), + select!(b, operands(v)[1], lhs_offset, rhs_offset), + lhs_skipload + end - if isa(v, LLVM.SelectInst) - lhs_v, lhs_offset, lhs_skipload = getparent(operands(v)[2], offset, hasload) - rhs_v, rhs_offset, rhs_skipload = getparent(operands(v)[3], offset, hasload) - if value_type(lhs_v) != value_type(rhs_v) || value_type(lhs_offset) != value_type(rhs_offset) || lhs_skipload != rhs_skipload msg = sprint() do io - println(io, "Could not analyze [select] garbage collection behavior of") + println(io, "Could not analyze garbage collection behavior of") + println(io, " inst: ", string(inst)) println(io, " v0: ", string(v0)) 
println(io, " v: ", string(v)) println(io, " offset: ", string(offset)) println(io, " hasload: ", string(hasload)) - println(io, " lhs_v", lhs_v) - println(io, " rhs_v", rhs_v) - println(io, " lhs_offset", lhs_offset) - println(io, " rhs_offset", rhs_offset) - println(io, " lhs_skipload", lhs_skipload) - println(io, " rhs_skipload", rhs_skipload) end bt = GPUCompiler.backtrace(inst) throw(EnzymeInternalError(msg, string(f), bt)) end - return select!(b, operands(v)[1], lhs_v, rhs_v), select!(b, operands(v)[1], lhs_offset, rhs_offset), lhs_skipload - end - msg = sprint() do io - println(io, "Could not analyze garbage collection behavior of") - println(io, " inst: ", string(inst)) - println(io, " v0: ", string(v0)) - println(io, " v: ", string(v)) - println(io, " offset: ", string(offset)) - println(io, " hasload: ", string(hasload)) - end - bt = GPUCompiler.backtrace(inst) - throw(EnzymeInternalError(msg, string(f), bt)) - end + v, offset, hadload = getparent(v, LLVM.ConstantInt(offty, 0), false) - v, offset, hadload = getparent(v, LLVM.ConstantInt(offty, 0), false) - - if addr == 13 - @assert hadload - end + if addr == 13 + @assert hadload + end - if eltype(value_type(v)) != el_ty - v = bitcast!(b, v, LLVM.PointerType(el_ty, addrspace(value_type(v)))) - end - push!(nvs, (v, pb)) - push!(offsets, (offset, pb)) - end + if eltype(value_type(v)) != el_ty + v = bitcast!( + b, + v, + LLVM.PointerType(el_ty, addrspace(value_type(v))), + ) + end + push!(nvs, (v, pb)) + push!(offsets, (offset, pb)) + end - nb = IRBuilder() - position!(nb, inst) - - offset = goffsets[inst] - append!(LLVM.incoming(offset), offsets) - if all(x->x[1]==offsets[1][1], offsets) - offset = offsets[1][1] - end + nb = IRBuilder() + position!(nb, inst) - nphi = nextvs[inst] - if !all(x->x[1]==nvs[1][1], nvs) - append!(LLVM.incoming(nphi), nvs) - else - replace_uses!(nphi, nvs[1][1]) - LLVM.API.LLVMInstructionEraseFromParent(nphi) - nphi = nvs[1][1] - end + offset = goffsets[inst] + 
append!(LLVM.incoming(offset), offsets) + if all(x -> x[1] == offsets[1][1], offsets) + offset = offsets[1][1] + end - position!(nb, nonphi) - if addr == 13 - nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) - nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) - nphi = load!(nb, ty, nphi) - else - nphi = addrspacecast!(nb, nphi, ty) - end - if !isa(offset, LLVM.ConstantInt) || convert(Int64, offset) != 0 - nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, addrspace(ty))) - nphi = gep!(nb, i8, nphi, [offset]) - nphi = bitcast!(nb, nphi, ty) + nphi = nextvs[inst] + if !all(x -> x[1] == nvs[1][1], nvs) + append!(LLVM.incoming(nphi), nvs) + else + replace_uses!(nphi, nvs[1][1]) + LLVM.API.LLVMInstructionEraseFromParent(nphi) + nphi = nvs[1][1] + end + + position!(nb, nonphi) + if addr == 13 + nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) + nphi = load!(nb, ty, nphi) + else + nphi = addrspacecast!(nb, nphi, ty) + end + if !isa(offset, LLVM.ConstantInt) || convert(Int64, offset) != 0 + nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, addrspace(ty))) + nphi = gep!(nb, i8, nphi, [offset]) + nphi = bitcast!(nb, nphi, ty) + end + replace_uses!(inst, nphi) + end + for inst in todo + LLVM.API.LLVMInstructionEraseFromParent(inst) + end end - replace_uses!(inst, nphi) - end - for inst in todo - LLVM.API.LLVMInstructionEraseFromParent(inst) end - end - end end return nothing end @@ -910,56 +1077,58 @@ function fix_decayaddr!(mod::LLVM.Module) temp = nothing for u in LLVM.uses(inst) st = LLVM.user(u) - # Storing _into_ the decay addr is okay - # we just cannot store the decayed addr into - # somewhere - if isa(st, LLVM.StoreInst) - if operands(st)[2] == inst - LLVM.API.LLVMSetOperand(st, 2-1, operands(inst)[1]) - continue - end - end - if isa(st, LLVM.LoadInst) - LLVM.API.LLVMSetOperand(st, 1-1, operands(inst)[1]) + # Storing _into_ the decay addr is okay + # we just cannot store the decayed addr into + # 
somewhere + if isa(st, LLVM.StoreInst) + if operands(st)[2] == inst + LLVM.API.LLVMSetOperand(st, 2 - 1, operands(inst)[1]) continue - end - # if isa(st, LLVM.InsertValueInst) - # if operands(st)[1] == inst - # push!(invalid, st) - # LLVM.API.LLVMSetOperand(st, 1-1, LLVM.UndefValue(value_type(inst))) - # continue - # end - # if operands(st)[2] == inst - # push!(invalid, st) - # LLVM.API.LLVMSetOperand(st, 2-1, LLVM.UndefValue(value_type(inst))) - # continue - # end - # end - if !isa(st, LLVM.CallInst) - bt = GPUCompiler.backtrace(st) - msg = sprint() do io::IO - println(io, string(f)) - println(io, inst) - println(io, st) - print(io, "Illegal decay of nonnull\n") - if bt !== nothing - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) - println(io) - end - end - throw(AssertionError(msg)) - end - + end + end + if isa(st, LLVM.LoadInst) + LLVM.API.LLVMSetOperand(st, 1 - 1, operands(inst)[1]) + continue + end + # if isa(st, LLVM.InsertValueInst) + # if operands(st)[1] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 1-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # if operands(st)[2] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 2-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # end + if !isa(st, LLVM.CallInst) + bt = GPUCompiler.backtrace(st) + msg = sprint() do io::IO + println(io, string(f)) + println(io, inst) + println(io, st) + print(io, "Illegal decay of nonnull\n") + if bt !== nothing + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end + end + throw(AssertionError(msg)) + end + fop = operands(st)[end] - + intr = LLVM.API.LLVMGetIntrinsicID(fop) - if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id + if intr == LLVM.Intrinsic("llvm.memcpy").id || + intr == LLVM.Intrinsic("llvm.memmove").id || + intr == LLVM.Intrinsic("llvm.memset").id newvs = LLVM.Value[] for (i, v) in 
enumerate(operands(st)[1:end-1]) if v == inst - LLVM.API.LLVMSetOperand(st, i-1, operands(inst)[1]) + LLVM.API.LLVMSetOperand(st, i - 1, operands(inst)[1]) push!(newvs, operands(inst)[1]) continue end @@ -976,22 +1145,36 @@ function fix_decayaddr!(mod::LLVM.Module) newi = memset!(nb, newvs[1], newvs[2], newvs[3], 0) end - for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(operands(st))-1)]...] + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(operands(st))-1) + ]..., + ] idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(st, idx); - - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(st, idx) + + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(st, idx, Attrs) - for j in 1:count - LLVM.API.LLVMAddCallSiteAttribute(newi, idx, unsafe_load(Attrs, j)) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newi, + idx, + unsafe_load(Attrs, j), + ) end Libc.free(Attrs) end - + API.EnzymeCopyMetadata(newi, st) - + LLVM.API.LLVMInstructionEraseFromParent(st) - continue + continue end mayread = false maywrite = false @@ -1051,7 +1234,7 @@ function fix_decayaddr!(mod::LLVM.Module) end throw(AssertionError(msg)) end - + elt = eltype(value_type(inst)) if temp === nothing nb = IRBuilder() @@ -1089,12 +1272,13 @@ function pre_attr!(mod::LLVM.Module) if isempty(blocks(fn)) continue end - if linkage(fn) != LLVM.API.LLVMInternalLinkage && linkage(fn) != LLVM.API.LLVMPrivateLinkage + if linkage(fn) != LLVM.API.LLVMInternalLinkage && + linkage(fn) != LLVM.API.LLVMPrivateLinkage continue end - + fty = LLVM.FunctionType(fn) - nfn = 
LLVM.Function(mod, "enzyme_attr_prev_"*LLVM.name(enzymefn), fty) + nfn = LLVM.Function(mod, "enzyme_attr_prev_" * LLVM.name(enzymefn), fty) LLVM.IRBuilder() do builder entry = BasicBlock(nfn, "entry") position!(builder, entry) @@ -1111,74 +1295,89 @@ function pre_attr!(mod::LLVM.Module) end function jl_inst_simplify!(PM) - ccall((:LLVMAddJLInstSimplifyPass, API.libEnzyme), Cvoid, (LLVM.API.LLVMPassManagerRef,), PM) + ccall( + (:LLVMAddJLInstSimplifyPass, API.libEnzyme), + Cvoid, + (LLVM.API.LLVMPassManagerRef,), + PM, + ) end -function post_attr!(mod::LLVM.Module) -end +function post_attr!(mod::LLVM.Module) end function prop_global!(g) newfns = String[] changed = false - todo = Tuple{Vector{Cuint},LLVM.Value}[] - for u in LLVM.uses(g) - u = LLVM.user(u) - push!(todo, (Cuint[],u)) - end - while length(todo) > 0 - path, var = pop!(todo) - if isa(var, LLVM.LoadInst) - B = IRBuilder() - position!(B, var) - res = LLVM.initializer(g) - for p in path - res = extract_value!(B, res, p) - end - changed = true - for u in LLVM.uses(var) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - f2 = LLVM.called_operand(u) - if isa(f2, LLVM.Function) - push!(newfns, LLVM.name(f2)) - end - end - end - replace_uses!(var, res) - eraseInst(LLVM.parent(var), var) - continue + todo = Tuple{Vector{Cuint},LLVM.Value}[] + for u in LLVM.uses(g) + u = LLVM.user(u) + push!(todo, (Cuint[], u)) + end + while length(todo) > 0 + path, var = pop!(todo) + if isa(var, LLVM.LoadInst) + B = IRBuilder() + position!(B, var) + res = LLVM.initializer(g) + for p in path + res = extract_value!(B, res, p) end - if isa(var, LLVM.AddrSpaceCastInst) - for u in LLVM.uses(var) - u = LLVM.user(u) - push!(todo, (path, u)) + changed = true + for u in LLVM.uses(var) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + f2 = LLVM.called_operand(u) + if isa(f2, LLVM.Function) + push!(newfns, LLVM.name(f2)) + end end - continue end - if isa(var, LLVM.ConstantExpr) && opcode(var) == LLVM.API.LLVMAddrSpaceCast - for u in LLVM.uses(var) 
- u = LLVM.user(u) - push!(todo, (path, u)) - end - continue + replace_uses!(var, res) + eraseInst(LLVM.parent(var), var) + continue + end + if isa(var, LLVM.AddrSpaceCastInst) + for u in LLVM.uses(var) + u = LLVM.user(u) + push!(todo, (path, u)) end - if isa(var, LLVM.GetElementPtrInst) - if all(isa(v, LLVM.ConstantInt) for v in operands(var)[2:end]) - if convert(Cuint, operands(var)[2]) == 0 - for u in LLVM.uses(var) - u = LLVM.user(u) - push!(todo, (vcat(path,collect((convert(Cuint, v) for v in operands(var)[3:end]))), u)) - end + continue + end + if isa(var, LLVM.ConstantExpr) && opcode(var) == LLVM.API.LLVMAddrSpaceCast + for u in LLVM.uses(var) + u = LLVM.user(u) + push!(todo, (path, u)) + end + continue + end + if isa(var, LLVM.GetElementPtrInst) + if all(isa(v, LLVM.ConstantInt) for v in operands(var)[2:end]) + if convert(Cuint, operands(var)[2]) == 0 + for u in LLVM.uses(var) + u = LLVM.user(u) + push!( + todo, + ( + vcat( + path, + collect(( + convert(Cuint, v) for v in operands(var)[3:end] + )), + ), + u, + ), + ) end - continue end + continue end end + end return changed, newfns end # From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 -function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly=false)::Bool +function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly = false)::Bool # we will ignore fense here if isa(inst, LLVM.StoreInst) return true @@ -1200,11 +1399,14 @@ function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly=false)::Bool end if isa(inst, LLVM.CallInst) || isa(inst, LLVM.InvokeInst) || isa(inst, LLVM.CallBrInst) idx = reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); - - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, 
+ Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j in 1:count + for j = 1:count attr = LLVM.Attribute(unsafe_load(Attrs, j)) if kind(attr) == kind(EnumAttribute("readnone")) return false @@ -1298,14 +1500,14 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end end end - + changed = set_readonly!(fn) if length(calls) == 0 || hasUser return changed end - for c in calls + for c in calls parentf = LLVM.parent(LLVM.parent(c)) push!(next, LLVM.name(parentf)) LLVM.API.LLVMInstructionEraseFromParent(c) @@ -1317,7 +1519,8 @@ end function propagate_returned!(mod::LLVM.Module) globs = LLVM.GlobalVariable[] for g in globals(mod) - if linkage(g) == LLVM.API.LLVMInternalLinkage || linkage(g) == LLVM.API.LLVMPrivateLinkage + if linkage(g) == LLVM.API.LLVMInternalLinkage || + linkage(g) == LLVM.API.LLVMPrivateLinkage if !isconstant(g) continue end @@ -1344,19 +1547,33 @@ function propagate_returned!(mod::LLVM.Module) changed = true end attrs = collect(function_attributes(fn)) - prevent = any(kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs) + prevent = any( + kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for + attr in attrs + ) # if any(kind(attr) == kind(EnumAttribute("noinline")) for attr in attrs) # continue # end argn = nothing toremove = Int64[] for (i, arg) in enumerate(parameters(fn)) - if any(kind(attr) == kind(EnumAttribute("returned")) for attr in collect(parameter_attributes(fn, i))) + if any( + kind(attr) == kind(EnumAttribute("returned")) for + attr in collect(parameter_attributes(fn, i)) + ) argn = i end # remove unused sret-like - if !prevent && (linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) && any(kind(attr) == kind(EnumAttribute("nocapture")) for attr in collect(parameter_attributes(fn, i))) + if !prevent && + ( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + 
linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) && + any( + kind(attr) == kind(EnumAttribute("nocapture")) for + attr in collect(parameter_attributes(fn, i)) + ) val = nothing illegalUse = false torem = LLVM.Instruction[] @@ -1454,7 +1671,10 @@ function propagate_returned!(mod::LLVM.Module) end # interprocedural const prop from callers of arg - if !prevent && (linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) + if !prevent && ( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) val = nothing illegalUse = false for u in LLVM.uses(fn) @@ -1479,9 +1699,9 @@ function propagate_returned!(mod::LLVM.Module) continue end @static if LLVM.version() >= v"12" - if isa(ops[i], LLVM.PoisonValue) - continue - end + if isa(ops[i], LLVM.PoisonValue) + continue + end end if ops[i] == arg continue @@ -1546,10 +1766,13 @@ function propagate_returned!(mod::LLVM.Module) end end if !baduse - push!(toremove, i-1) + push!(toremove, i - 1) end end - illegalUse = !(linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) + illegalUse = !( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) hasAnyUse = false for u in LLVM.uses(fn) un = LLVM.user(u) @@ -1588,7 +1811,9 @@ function propagate_returned!(mod::LLVM.Module) end end #if the function return has no users whatsoever, remove it - if argn === nothing && !hasAnyUse && LLVM.return_type(LLVM.function_type(fn)) != LLVM.VoidType() + if argn === nothing && + !hasAnyUse && + LLVM.return_type(LLVM.function_type(fn)) != LLVM.VoidType() argn = -1 end if argn === nothing && length(toremove) == 0 @@ -1605,7 +1830,9 @@ function propagate_returned!(mod::LLVM.Module) un = LLVM.user(u) push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) end - nfn = LLVM.Function(API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove)) + nfn = LLVM.Function( + 
API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove), + ) for u in LLVM.uses(fn) un = LLVM.user(u) push!(todo, un) @@ -1620,7 +1847,7 @@ function propagate_returned!(mod::LLVM.Module) eraseInst(mod, fn) changed = true catch - break + break end end if !changed @@ -1629,7 +1856,8 @@ function propagate_returned!(mod::LLVM.Module) todo = LLVM.Function[] for name in next fn = functions(mod)[name] - if linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage + if linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage has_user = false for u in LLVM.uses(fn) has_user = true @@ -1651,11 +1879,11 @@ function detect_writeonly!(mod::LLVM.Module) end for (i, a) in enumerate(parameters(f)) if isa(value_type(a), LLVM.PointerType) - todo = Tuple{LLVM.Value, LLVM.Instruction}[] + todo = Tuple{LLVM.Value,LLVM.Instruction}[] for u in LLVM.uses(a) push!(todo, (a, LLVM.user(u))) end - seen = Set{Tuple{LLVM.Value, LLVM.Instruction}}() + seen = Set{Tuple{LLVM.Value,LLVM.Instruction}}() mayread = false maywrite = false while length(todo) > 0 @@ -1665,20 +1893,22 @@ function detect_writeonly!(mod::LLVM.Module) end push!(seen, cur) curv, curi = cur - + if isa(curi, LLVM.StoreInst) if operands(curi)[1] != curv maywrite = true continue end end - + if isa(curi, LLVM.LoadInst) mayread = true continue end - if isa(curi, LLVM.GetElementPtrInst) || isa(curi, LLVM.BitCastInst) || isa(curi, LLVM.AddrSpaceCastInst) + if isa(curi, LLVM.GetElementPtrInst) || + isa(curi, LLVM.BitCastInst) || + isa(curi, LLVM.AddrSpaceCastInst) for u in LLVM.uses(curi) push!(todo, (curi, LLVM.user(u))) end @@ -1687,20 +1917,47 @@ function detect_writeonly!(mod::LLVM.Module) mayread = true maywrite = true end - if any(map(k->kind(k)==kind(EnumAttribute("readnone")), collect(parameter_attributes(f, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("readnone")), + collect(parameter_attributes(f, i)), + ), + ) mayread = false 
maywrite = false end - if any(map(k->kind(k)==kind(EnumAttribute("readonly")), collect(parameter_attributes(f, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("readonly")), + collect(parameter_attributes(f, i)), + ), + ) maywrite = false end - if any(map(k->kind(k)==kind(EnumAttribute("writeonly")), collect(parameter_attributes(f, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("writeonly")), + collect(parameter_attributes(f, i)), + ), + ) mayread = false end - - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, LLVM.API.LLVMAttributeIndex(i), kind(EnumAttribute("readnone"))) - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, LLVM.API.LLVMAttributeIndex(i), kind(EnumAttribute("readonly"))) - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, LLVM.API.LLVMAttributeIndex(i), kind(EnumAttribute("writeonly"))) + + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("readnone")), + ) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("readonly")), + ) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("writeonly")), + ) if !mayread && !maywrite push!(parameter_attributes(f, i), LLVM.EnumAttribute("readnone", 0)) @@ -1752,17 +2009,26 @@ function validate_return_roots!(mod) if length(enzyme_srets) >= 1 && length(srets) == 0 @assert enzyme_srets[1] == 1 VT = LLVM.VoidType() - if length(enzyme_srets) == 1 && LLVM.return_type(LLVM.function_type(f)) == VT && length(enzyme_srets_v) == 0 + if length(enzyme_srets) == 1 && + LLVM.return_type(LLVM.function_type(f)) == VT && + length(enzyme_srets_v) == 0 # Upgrading to sret requires writeonly - if !any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in collect(parameter_attributes(f, 1))) - msg = sprint() do io::IO - println(io, "Enzyme internal error (not writeonly sret)") - println(io, string(f)) - println(io, "collect(parameter_attributes(f, 1))=", 
collect(parameter_attributes(f, 1))) - end - throw(AssertionError(msg)) - end - + if !any( + kind(attr) == kind(EnumAttribute("writeonly")) for + attr in collect(parameter_attributes(f, 1)) + ) + msg = sprint() do io::IO + println(io, "Enzyme internal error (not writeonly sret)") + println(io, string(f)) + println( + io, + "collect(parameter_attributes(f, 1))=", + collect(parameter_attributes(f, 1)), + ) + end + throw(AssertionError(msg)) + end + alty = nothing for u in LLVM.uses(f) u = LLVM.user(u) @@ -1770,13 +2036,13 @@ function validate_return_roots!(mod) @assert LLVM.called_operand(u) == f alop = operands(u)[1] if !isa(alop, LLVM.AllocaInst) - msg = sprint() do io::IO - println(io, "Enzyme internal error (!isa(alop, LLVM.AllocaInst))") - println(io, "alop=", alop) - println(io, "u=", u) - println(io, "f=", string(f)) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme internal error (!isa(alop, LLVM.AllocaInst))") + println(io, "alop=", alop) + println(io, "u=", u) + println(io, "f=", string(f)) + end + throw(AssertionError(msg)) end @assert isa(alop, LLVM.AllocaInst) @@ -1791,8 +2057,17 @@ function validate_return_roots!(mod) else EnumAttribute("sret") end - LLVM.API.LLVMAddCallSiteAttribute(u, LLVM.API.LLVMAttributeIndex(1), attr) - LLVM.API.LLVMRemoveCallSiteStringAttribute(u, LLVM.API.LLVMAttributeIndex(1), "enzyme_sret", length("enzyme_sret")) + LLVM.API.LLVMAddCallSiteAttribute( + u, + LLVM.API.LLVMAttributeIndex(1), + attr, + ) + LLVM.API.LLVMRemoveCallSiteStringAttribute( + u, + LLVM.API.LLVMAttributeIndex(1), + "enzyme_sret", + length("enzyme_sret"), + ) end @assert alty !== nothing attr = if LLVM.version().major >= 12 @@ -1806,7 +2081,7 @@ function validate_return_roots!(mod) srets = [(1, attr)] enzyme_srets = Int[] else - + enzyme_srets2 = Int[] for idx in enzyme_srets alty = nothing @@ -1821,10 +2096,18 @@ function validate_return_roots!(mod) if any_jltypes(nty) bad = true end - 
LLVM.API.LLVMRemoveCallSiteStringAttribute(u, LLVM.API.LLVMAttributeIndex(idx), "enzyme_sret", length("enzyme_sret")) + LLVM.API.LLVMRemoveCallSiteStringAttribute( + u, + LLVM.API.LLVMAttributeIndex(idx), + "enzyme_sret", + length("enzyme_sret"), + ) end if !bad - delete!(parameter_attributes(f, idx), StringAttribute("enzyme_sret")) + delete!( + parameter_attributes(f, idx), + StringAttribute("enzyme_sret"), + ) else push!(enzyme_srets2, idx) end @@ -1832,16 +2115,16 @@ function validate_return_roots!(mod) enzyme_srets = enzyme_srets2 if length(enzyme_srets) != 0 - msg = sprint() do io::IO - println(io, "Enzyme internal error (length(enzyme_srets) != 0)") - println(io, "f=", string(f)) - println(io, "enzyme_srets=", enzyme_srets) - println(io, "enzyme_srets_v=", enzyme_srets_v) - println(io, "srets=", srets) - println(io, "rroots=", rroots) - println(io, "rroots_v=", rroots_v) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme internal error (length(enzyme_srets) != 0)") + println(io, "f=", string(f)) + println(io, "enzyme_srets=", enzyme_srets) + println(io, "enzyme_srets_v=", enzyme_srets_v) + println(io, "srets=", srets) + println(io, "rroots=", rroots) + println(io, "rroots_v=", rroots_v) + end + throw(AssertionError(msg)) end end end @@ -1860,7 +2143,7 @@ function validate_return_roots!(mod) end end -function checkNoAssumeFalse(mod, shouldshow=false) +function checkNoAssumeFalse(mod, shouldshow = false) for f in functions(mod) for bb in blocks(f), inst in instructions(bb) if !isa(inst, LLVM.CallInst) @@ -1885,7 +2168,8 @@ function checkNoAssumeFalse(mod, shouldshow=false) end end if isa(op, LLVM.ICmpInst) - if predicate_int(op) == LLVM.API.LLVMIntNE && operands(op)[1] == operands(op)[2] + if predicate_int(op) == LLVM.API.LLVMIntNE && + operands(op)[1] == operands(op)[2] msg = sprint() do io println(io, "Enzyme Internal Error: non-icmp assume condition") println(io, "mod=", string(mod)) @@ -1913,17 +2197,55 @@ function 
removeDeadArgs!(mod::LLVM.Module, tm) LLVM.run!(pm, mod) end # Prevent dead-arg-elimination of functions which we may require args for in the derivative - funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg=true) + funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg = true) if LLVM.version().major <= 15 - func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"), EnumAttribute("nofree")]) - rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) - sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) + func, _ = get_function!( + mod, + "llvm.enzymefakeuse", + funcT, + [EnumAttribute("readnone"), EnumAttribute("nofree")], + ) + rfunc, _ = get_function!( + mod, + "llvm.enzymefakeread", + funcT, + [ + EnumAttribute("readonly"), + EnumAttribute("nofree"), + EnumAttribute("argmemonly"), + ], + ) + sfunc, _ = get_function!( + mod, + "llvm.enzyme.sret_use", + funcT, + [ + EnumAttribute("readonly"), + EnumAttribute("nofree"), + EnumAttribute("argmemonly"), + ], + ) else - func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")]) - rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) - sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) + func, _ = get_function!( + mod, + "llvm.enzymefakeuse", + funcT, + [EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")], + ) + rfunc, _ = get_function!( + mod, + "llvm.enzymefakeread", + funcT, + [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + ) + sfunc, _ = get_function!( + mod, + "llvm.enzyme.sret_use", + funcT, + 
[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + ) end - + for fn in functions(mod) if isempty(blocks(fn)) continue @@ -1933,7 +2255,12 @@ function removeDeadArgs!(mod::LLVM.Module, tm) # active both can occur on 4. If the original sret is removed (at index 1) we no longer need # to preserve this. for idx in (2, 3, 4) - if length(collect(parameters(fn))) >= idx && any( ( kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || kind(attr) == kind(StringAttribute("enzymejl_returnRoots_v"))) for attr in collect(parameter_attributes(fn, idx))) + if length(collect(parameters(fn))) >= idx && any( + ( + kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || + kind(attr) == kind(StringAttribute("enzymejl_returnRoots_v")) + ) for attr in collect(parameter_attributes(fn, idx)) + ) for u in LLVM.uses(fn) u = LLVM.user(u) @assert isa(u, LLVM.CallInst) @@ -1944,7 +2271,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm) cl = call!(B, funcT, rfunc, LLVM.Value[inp]) if isa(value_type(inp), LLVM.PointerType) LLVM.API.LLVMAddCallSiteAttribute( - cl, LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nocapture") + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), ) end end @@ -1960,7 +2289,13 @@ function removeDeadArgs!(mod::LLVM.Module, tm) continue end attrs = collect(parameter_attributes(fn, idx)) - if any( ( kind(attr) == sretkind || kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs) + if any( + ( + kind(attr) == sretkind || + kind(attr) == kind(StringAttribute("enzyme_sret")) || + kind(attr) == kind(StringAttribute("enzyme_sret_v")) + ) for attr in attrs + ) for u in LLVM.uses(fn) u = LLVM.user(u) if isa(u, LLVM.ConstantExpr) @@ -1977,14 +2312,18 @@ function removeDeadArgs!(mod::LLVM.Module, tm) cl = call!(B, funcT, sfunc, LLVM.Value[inp]) if isa(value_type(inp), LLVM.PointerType) LLVM.API.LLVMAddCallSiteAttribute( - cl, 
LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nocapture") + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), ) end end end end attrs = collect(function_attributes(fn)) - prevent = any(kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs) + prevent = any( + kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs + ) # && any(kind(attr) == kind(StringAttribute("enzyme_math")) for attr in attrs) if prevent B = IRBuilder() @@ -2009,7 +2348,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) API.EnzymeAddAttributorLegacyPass(pm) LLVM.run!(pm, mod) end - end + end end propagate_returned!(mod) ModulePassManager() do pm @@ -2088,7 +2427,7 @@ function optimize!(mod::LLVM.Module, tm) alloc_opt_tm!(pm, tm) loop_idiom!(pm) loop_rotate!(pm) - + loop_optimizations_tm!(pm, tm) instruction_combining!(pm) @@ -2099,7 +2438,7 @@ function optimize!(mod::LLVM.Module, tm) alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? gvn!(pm) - + # This InstCombine needs to be after GVN # Otherwise it will generate load chains in GPU code... instruction_combining!(pm) @@ -2117,7 +2456,7 @@ function optimize!(mod::LLVM.Module, tm) jump_threading!(pm) correlated_value_propagation!(pm) # SLP_Vectorizer -- not for Enzyme - + LLVM.run!(pm, mod) aggressive_dce!(pm) @@ -2243,7 +2582,7 @@ function addMachinePasses!(pm, tm) gvn!(pm) end -function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) +function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics = true) if lower_intrinsics # LowerPTLS removes an indirect call. 
As a result, it is likely to trigger # LLVM's devirtualization heuristics, which would result in the entire @@ -2267,7 +2606,7 @@ function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) sccp!(pm) # Remove dead use of ptls dce!(pm) - lower_ptls_tm!(pm, tm, #=dump_native=# false) + lower_ptls_tm!(pm, tm, false) #=dump_native=# instruction_combining!(pm) jl_inst_simplify!(pm) # Clean up write barrier and ptls lowering @@ -2278,7 +2617,7 @@ function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) end end -function post_optimze!(mod, tm, machine=true) +function post_optimze!(mod, tm, machine = true) addr13NoAlias(mod) removeDeadArgs!(mod, tm) for f in collect(functions(mod)) @@ -2289,7 +2628,14 @@ function post_optimze!(mod, tm, machine=true) end out_error = Ref{Cstring}() if LLVM.API.LLVMVerifyModule(mod, LLVM.API.LLVMReturnStatusAction, out_error) != 0 - throw(LLVM.LLVMException("broken gc calling conv fix\n"*string(unsafe_string(out_error[]))*"\n"*string(mod))) + throw( + LLVM.LLVMException( + "broken gc calling conv fix\n" * + string(unsafe_string(out_error[])) * + "\n" * + string(mod), + ), + ) end LLVM.ModulePassManager() do pm addTargetPasses!(pm, tm, LLVM.triple(mod)) diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 482a961b52..4b8f2d202a 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -3,7 +3,7 @@ module JIT using LLVM using Libdl -import LLVM:TargetMachine +import LLVM: TargetMachine import GPUCompiler import ..Compiler @@ -13,8 +13,8 @@ export get_trampoline struct CompilerInstance jit::LLVM.JuliaOJIT - lctm::Union{LLVM.LazyCallThroughManager, Nothing} - ism::Union{LLVM.IndirectStubsManager, Nothing} + lctm::Union{LLVM.LazyCallThroughManager,Nothing} + ism::Union{LLVM.IndirectStubsManager,Nothing} end function LLVM.dispose(ci::CompilerInstance) @@ -35,24 +35,24 @@ get_tm() = tm[] get_jit() = jit[].jit function absolute_symbol_materialization(name, ptr) - address = 
LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr)) - flags = LLVM.API.LLVMJITSymbolFlags(LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) - symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) - gv = if LLVM.version() >= v"15" - LLVM.API.LLVMOrcCSymbolMapPair(name, symbol) - else - LLVM.API.LLVMJITCSymbolMapPair(name, symbol) - end - return LLVM.absolute_symbols(Ref(gv)) + address = LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr)) + flags = LLVM.API.LLVMJITSymbolFlags(LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) + gv = if LLVM.version() >= v"15" + LLVM.API.LLVMOrcCSymbolMapPair(name, symbol) + else + LLVM.API.LLVMJITCSymbolMapPair(name, symbol) + end + return LLVM.absolute_symbols(Ref(gv)) end function define_absolute_symbol(jd, name) - ptr = LLVM.find_symbol(name) - if ptr !== C_NULL - LLVM.define(jd, absolute_symbol_materialization(name, ptr)) - return true - end - return false + ptr = LLVM.find_symbol(name) + if ptr !== C_NULL + LLVM.define(jd, absolute_symbol_materialization(name, ptr)) + return true + end + return false end function __init__() @@ -68,7 +68,7 @@ function __init__() tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel) LLVM.asm_verbosity!(tempTM, true) tm[] = tempTM - + lljit = JuliaOJIT() jd_main = JITDylib(lljit) @@ -77,10 +77,10 @@ function __init__() dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix) LLVM.add!(jd_main, dg) - if Sys.iswindows() && Int === Int64 - # TODO can we check isGNU? - define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms")) - end + if Sys.iswindows() && Int === Int64 + # TODO can we check isGNU? 
+ define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms")) + end es = ExecutionSession(lljit) try @@ -95,12 +95,18 @@ function __init__() hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) for (k, v) in Compiler.JuliaGlobalNameMap ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) - LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, "ejl_"*k), ptr)) + LLVM.define( + jd_main, + absolute_symbol_materialization(mangle(lljit, "ejl_" * k), ptr), + ) end for (k, v) in Compiler.JuliaEnzymeNameMap ptr = Compiler.unsafe_to_ptr(v) - LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, "ejl_"*k), ptr)) + LLVM.define( + jd_main, + absolute_symbol_materialization(mangle(lljit, "ejl_" * k), ptr), + ) end atexit() do @@ -123,13 +129,15 @@ end function add_trampoline!(jd, (lljit, lctm, ism), entry, target) flags = LLVM.API.LLVMJITSymbolFlags( - LLVM.API.LLVMJITSymbolGenericFlagsCallable | - LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + LLVM.API.LLVMJITSymbolGenericFlagsCallable | + LLVM.API.LLVMJITSymbolGenericFlagsExported, + 0, + ) alias = LLVM.API.LLVMOrcCSymbolAliasMapPair( - mangle(lljit, entry), - LLVM.API.LLVMOrcCSymbolAliasMapEntry( - mangle(lljit, target), flags)) + mangle(lljit, entry), + LLVM.API.LLVMOrcCSymbolAliasMapEntry(mangle(lljit, target), flags), + ) mu = LLVM.reexports(lctm, ism, jd, [alias]) LLVM.define(jd, mu) @@ -140,8 +148,8 @@ end function get_trampoline(job) compiler = jit[] lljit = compiler.jit - lctm = compiler.lctm - ism = compiler.ism + lctm = compiler.lctm + ism = compiler.ism if lctm === nothing || ism === nothing error("Delayed compilation not available.") @@ -155,8 +163,7 @@ function get_trampoline(job) sym = String(gensym(:func)) _sym = String(gensym(:func)) - addr = add_trampoline!(jd, (lljit, lctm, ism), - _sym, sym) + addr = add_trampoline!(jd, (lljit, lctm, ism), _sym, sym) # 3. 
add MU that will call back into the compiler function materialize(mr) @@ -193,16 +200,14 @@ function get_trampoline(job) function discard(jd, sym) end flags = LLVM.API.LLVMJITSymbolFlags( - LLVM.API.LLVMJITSymbolGenericFlagsCallable | - LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + LLVM.API.LLVMJITSymbolGenericFlagsCallable | + LLVM.API.LLVMJITSymbolGenericFlagsExported, + 0, + ) - symbols = [ - LLVM.API.LLVMOrcCSymbolFlagsMapPair( - mangle(lljit, sym), flags), - ] + symbols = [LLVM.API.LLVMOrcCSymbolFlagsMapPair(mangle(lljit, sym), flags)] - mu = LLVM.CustomMaterializationUnit(sym, symbols, - materialize, discard) + mu = LLVM.CustomMaterializationUnit(sym, symbols, materialize, discard) LLVM.define(jd, mu) return addr end diff --git a/src/compiler/passes.jl b/src/compiler/passes.jl index e7e3c3c0a7..403b2bfa04 100644 --- a/src/compiler/passes.jl +++ b/src/compiler/passes.jl @@ -1,5 +1,5 @@ function reinsert_gcmarker_pass!(fn::LLVM.Function) - reinsert_gcmarker!(fn) + reinsert_gcmarker!(fn) unique_gcmarker!(fn) return true end diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 583a6f2f68..304372951c 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -1,11 +1,27 @@ -function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); - run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, ErrIfFuncWritten=false, RuntimeActivity=true, kwargs...) 
+function get_job( + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + run_enzyme::Bool = true, + mode::API.CDerivativeMode = API.DEM_ReverseModeCombined, + dupClosure::Bool = false, + argwrap::Bool = true, + width::Int = 1, + modifiedBetween = nothing, + returnPrimal::Bool = false, + augmentedInit = false, + world = nothing, + ABI = DefaultABI, + ErrIfFuncWritten = false, + RuntimeActivity = true, + kwargs..., +) - tt = Tuple{map(eltype, types.parameters)...} + tt = Tuple{map(eltype, types.parameters)...} if world === nothing world = codegen_world_age(Core.Typeof(func), tt) end - + primal = fspec(Core.Typeof(func), types, world) rt = Core.Compiler.return_type(func, tt, world) @@ -15,16 +31,40 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); defaultMod = mode != API.DEM_ReverseModeCombined && mode != API.DEM_ForwardMode modifiedBetween = (defaultMod, (defaultMod for _ in types.parameters)...) end - params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) - return Compiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world) + params = Compiler.EnzymeCompilerParams( + Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)},types.parameters...}, + mode, + width, + rt, + run_enzyme, + argwrap, + modifiedBetween, + returnPrimal, + augmentedInit, + Compiler.UnknownTapeType, + ABI, + ErrIfFuncWritten, + RuntimeActivity, + ) + return Compiler.CompilerJob( + primal, + CompilerConfig(target, params; kernel = false), + world, + ) end -function reflect(@nospecialize(func), @nospecialize(A), @nospecialize(types); - optimize::Bool=true, second_stage::Bool=true, kwargs...) 
+function reflect( + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + optimize::Bool = true, + second_stage::Bool = true, + kwargs..., +) job = get_job(func, A, types; kwargs...) # Codegen the primal function and all its dependency in one module - mod, meta = Compiler.codegen(:llvm, job; optimize #= validate=false =#) + mod, meta = Compiler.codegen(:llvm, job; optimize) #= validate=false =# if second_stage post_optimze!(mod, JIT.get_tm()) @@ -40,33 +80,62 @@ struct jl_llvmf_dump F::LLVM.API.LLVMValueRef end -function enzyme_code_llvm(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types); - optimize::Bool=true, run_enzyme::Bool=true, second_stage::Bool=true, - raw::Bool=false, debuginfo::Symbol=:default, dump_module::Bool=false, mode=API.DEM_ReverseModeCombined) +function enzyme_code_llvm( + io::IO, + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + optimize::Bool = true, + run_enzyme::Bool = true, + second_stage::Bool = true, + raw::Bool = false, + debuginfo::Symbol = :default, + dump_module::Bool = false, + mode = API.DEM_ReverseModeCombined, +) JuliaContext() do ctx entry_fn, ir = reflect(func, A, types; optimize, run_enzyme, second_stage, mode) ts_mod = ThreadSafeModule(ir) GC.@preserve ts_mod entry_fn begin value = Ref(jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) - str = ccall(:jl_dump_function_ir, Ref{String}, - (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), - value, !raw, dump_module, debuginfo) + str = ccall( + :jl_dump_function_ir, + Ref{String}, + (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), + value, + !raw, + dump_module, + debuginfo, + ) end print(io, str) end end -enzyme_code_llvm(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = enzyme_code_llvm(stdout, func, A, types; kwargs...) +enzyme_code_llvm(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = + enzyme_code_llvm(stdout, func, A, types; kwargs...) 
-function enzyme_code_native(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types); mode=API.DEM_ReverseModeCombined) +function enzyme_code_native( + io::IO, + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + mode = API.DEM_ReverseModeCombined, +) JuliaContext() do ctx _, mod = reflect(func, A, types; mode) str = String(LLVM.emit(JIT.get_tm(), mod, LLVM.API.LLVMAssemblyFile)) print(io, str) end end -enzyme_code_native(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = enzyme_code_native(stdout, func, A, types; kwargs...) +enzyme_code_native(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = + enzyme_code_native(stdout, func, A, types; kwargs...) -function enzyme_code_typed(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) +function enzyme_code_typed( + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + kwargs..., +) job = get_job(func, A, types; kwargs...) GPUCompiler.code_typed(job; kwargs...) 
end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index cde5d2cade..8a801067eb 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -2,16 +2,9 @@ struct MemoryEffect data::UInt32 end -@enum(ModRefInfo, - MRI_NoModRef = 0, - MRI_Ref = 1, - MRI_Mod = 2, - MRI_ModRef = 3) +@enum(ModRefInfo, MRI_NoModRef = 0, MRI_Ref = 1, MRI_Mod = 2, MRI_ModRef = 3) -@enum(IRMemLocation, - ArgMem = 0, - InaccessibleMem = 1, - Other = 2) +@enum(IRMemLocation, ArgMem = 0, InaccessibleMem = 1, Other = 2) const BitsPerLoc = UInt32(2) const LocMask = UInt32((1 << BitsPerLoc) - 1) @@ -27,14 +20,30 @@ end function Base.:&(lhs::ModRefInfo, rhs::ModRefInfo) ModRefInfo(UInt32(lhs) & UInt32(rhs)) end -const AllEffects = MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_ModRef << getLocationPos(Other))) -const ReadOnlyEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_Ref << getLocationPos(Other))) -const ReadOnlyArgMemEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) -const NoEffects = MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) +const AllEffects = MemoryEffect( + (MRI_ModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_ModRef << getLocationPos(Other)), +) +const ReadOnlyEffects = MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_Ref << getLocationPos(Other)), +) +const ReadOnlyArgMemEffects = MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), +) +const NoEffects = MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_NoModRef << 
getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), +) # Get ModRefInfo for any location. function getModRef(effect::MemoryEffect, loc::IRMemLocation)::ModRefInfo - ModRefInfo((effect.data >> getLocationPos(loc)) & LocMask) + ModRefInfo((effect.data >> getLocationPos(loc)) & LocMask) end function getModRef(effect::MemoryEffect)::ModRefInfo @@ -54,7 +63,7 @@ end function setModRef(effect::MemoryEffect)::MemoryEffect for loc in (ArgMem, InaccessibleMem, Other) - effect = setModRef(effect, mri)= getModRef(effect, loc) + effect = setModRef(effect, mri) = getModRef(effect, loc) end return effect end @@ -93,12 +102,12 @@ function is_writeonly(mri::ModRefInfo) end for n in (:is_readonly, :is_readnone, :is_writeonly) -@eval begin - function $n(memeffect::MemoryEffect) - return $n(getModRef(memeffect)) + @eval begin + function $n(memeffect::MemoryEffect) + return $n(getModRef(memeffect)) + end end end -end function is_noreturn(f::LLVM.Function) for attr in collect(function_attributes(f)) @@ -120,7 +129,8 @@ function is_readonly(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end - if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || + LLVM.name(f) == "llvm.julia.gc_preserve_end" return true end for attr in collect(function_attributes(f)) @@ -131,12 +141,12 @@ function is_readonly(f::LLVM.Function) return true end if LLVM.version().major > 15 - if kind(attr) == kind(EnumAttribute("memory")) - if is_readonly(MemoryEffect(value(attr))) - return true + if kind(attr) == kind(EnumAttribute("memory")) + if is_readonly(MemoryEffect(value(attr))) + return true + end end end - end end return false end @@ -152,7 +162,8 @@ function is_readnone(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end - if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + if LLVM.name(f) 
== "llvm.julia.gc_preserve_begin" || + LLVM.name(f) == "llvm.julia.gc_preserve_end" return true end for attr in collect(function_attributes(cur)) @@ -160,12 +171,12 @@ function is_readnone(f::LLVM.Function) return true end if LLVM.version().major > 15 - if kind(attr) == kind(EnumAttribute("memory")) - if is_readnone(MemoryEffect(value(attr))) - return true + if kind(attr) == kind(EnumAttribute("memory")) + if is_readnone(MemoryEffect(value(attr))) + return true + end end end - end end return false end @@ -181,7 +192,8 @@ function is_writeonly(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end - if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || + LLVM.name(f) == "llvm.julia.gc_preserve_end" return true end for attr in collect(function_attributes(cur)) @@ -192,12 +204,12 @@ function is_writeonly(f::LLVM.Function) return true end if LLVM.version().major > 15 - if kind(attr) == kind(EnumAttribute("memory")) - if is_writeonly(MemoryEffect(value(attr))) - return true + if kind(attr) == kind(EnumAttribute("memory")) + if is_writeonly(MemoryEffect(value(attr))) + return true + end end end - end end return false end @@ -205,7 +217,8 @@ end function set_readonly!(fn::LLVM.Function) attrs = collect(function_attributes(fn)) if LLVM.version().major <= 15 - if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && + !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs) delete!(function_attributes(fn), EnumAttribute("writeonly")) push!(function_attributes(fn), EnumAttribute("readnone")) @@ -224,12 +237,20 @@ function set_readonly!(fn::LLVM.Function) return old != eff end end - 
push!(function_attributes(fn), EnumAttribute("memory", set_readonly(AllEffects).data)) + push!( + function_attributes(fn), + EnumAttribute("memory", set_readonly(AllEffects).data), + ) return true end end -function get_function!(mod::LLVM.Module, name::AbstractString, FT::LLVM.FunctionType, attrs=[]) +function get_function!( + mod::LLVM.Module, + name::AbstractString, + FT::LLVM.FunctionType, + attrs = [], +) if haskey(functions(mod), name) F = functions(mod)[name] PT = LLVM.PointerType(FT) @@ -261,8 +282,12 @@ T_ppjlvalue() = LLVM.PointerType(LLVM.PointerType(LLVM.StructType(LLVMType[]))) return v end -function declare_pgcstack!(mod) - get_function!(mod, "julia.get_pgcstack", LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue()))) +function declare_pgcstack!(mod) + get_function!( + mod, + "julia.get_pgcstack", + LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue())), + ) end function emit_pgcstack(B) @@ -285,14 +310,19 @@ function get_pgcstack(func) return nothing end -function reinsert_gcmarker!(func, PB=nothing) +function reinsert_gcmarker!(func, PB = nothing) for (i, v) in enumerate(parameters(func)) - if any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(func, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(func, i)), + ), + ) return v end end - pgs = get_pgcstack(func) + pgs = get_pgcstack(func) if pgs === nothing context(LLVM.parent(func)) B = IRBuilder() @@ -303,13 +333,13 @@ function reinsert_gcmarker!(func, PB=nothing) position!(B, entry_bb) end emit_pgcstack(B) - else + else entry_bb = first(blocks(func)) fst = first(instructions(entry_bb)) if fst != pgs API.moveBefore(pgs, fst, PB === nothing ? 
C_NULL : PB.ref) end - pgs + pgs end end @@ -332,7 +362,7 @@ function unique_gcmarker!(func) end end if length(found) > 1 - for i in 2:length(found) + for i = 2:length(found) LLVM.replace_uses!(found[i], found[1]) ops = LLVM.collect(operands(found[i])) eraseInst(entry_bb, found[i]) @@ -341,7 +371,8 @@ function unique_gcmarker!(func) return nothing end -@inline AnonymousStruct(::Type{U}) where U<:Tuple = NamedTuple{ntuple(i->Symbol(i), Val(length(U.parameters))), U} +@inline AnonymousStruct(::Type{U}) where {U<:Tuple} = + NamedTuple{ntuple(i -> Symbol(i), Val(length(U.parameters))),U} # recursively compute the eltype type indexed by idx[0], idx[1], ... function recursive_eltype(val::LLVM.Value, idxs::Vector{Cuint}) @@ -359,7 +390,15 @@ end # Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float} # and that Bool -> i8, not i1 -function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev::LLVM.Value=LLVM.UndefValue(tape), lidxs::Vector{Cuint}=Cuint[], ridxs::Vector{Cuint}=Cuint[], emesg=nothing)::LLVM.Value +function calling_conv_fixup( + builder, + val::LLVM.Value, + tape::LLVM.LLVMType, + prev::LLVM.Value = LLVM.UndefValue(tape), + lidxs::Vector{Cuint} = Cuint[], + ridxs::Vector{Cuint} = Cuint[], + emesg = nothing, +)::LLVM.Value ctype = recursive_eltype(val, lidxs) if ctype == tape if length(lidxs) != 0 @@ -377,9 +416,9 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: @assert length(ctype) == length(elements(tape)) for (i, ty) in enumerate(elements(tape)) ln = copy(lidxs) - push!(ln, i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, ty, prev, ln, rn, emesg) end return prev @@ -388,9 +427,9 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: @assert length(elements(ctype)) == length(elements(tape)) for (i, ty) in enumerate(elements(tape)) ln = copy(lidxs) - push!(ln, 
i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, ty, prev, ln, rn, emesg) end return prev @@ -398,29 +437,31 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: elseif isa(tape, LLVM.ArrayType) if isa(ctype, LLVM.ArrayType) @assert length(ctype) == length(tape) - for i in 1:length(tape) + for i = 1:length(tape) ln = copy(lidxs) - push!(ln, i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn, emesg) end return prev end if isa(ctype, LLVM.StructType) @assert length(elements(ctype)) == length(tape) - for i in 1:length(tape) + for i = 1:length(tape) ln = copy(lidxs) - push!(ln, i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn, emesg) end return prev end end - if isa(tape, LLVM.IntegerType) && LLVM.width(tape) == 1 && LLVM.width(ctype) != LLVM.width(tape) + if isa(tape, LLVM.IntegerType) && + LLVM.width(tape) == 1 && + LLVM.width(ctype) != LLVM.width(tape) if length(lidxs) != 0 val = API.e_extract_value!(builder, val, lidxs) end @@ -431,7 +472,9 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: val end end - if isa(tape, LLVM.PointerType) && isa(ctype, LLVM.PointerType) && LLVM.addrspace(tape) == LLVM.addrspace(ctype) + if isa(tape, LLVM.PointerType) && + isa(ctype, LLVM.PointerType) && + LLVM.addrspace(tape) == LLVM.addrspace(ctype) if length(lidxs) != 0 val = API.e_extract_value!(builder, val, lidxs) end @@ -451,7 +494,7 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: msg2 = sprint() do io println(io, "Enzyme Internal Error: Illegal calling convention fixup") - if emesg !== nothing + if emesg !== nothing emesg(io) end println(io, "ctype = ", ctype) @@ -461,7 +504,11 @@ function calling_conv_fixup(builder, 
val::LLVM.Value, tape::LLVM.LLVMType, prev: println(io, "lidxs = ", lidxs) println(io, "ridxs = ", ridxs) println(io, "tape_type(tape) = ", tape_type(tape)) - println(io, "convert(LLVMType, tape_type(tape)) = ", convert(LLVM.LLVMType, tape_type(tape); allow_boxed=true)) + println( + io, + "convert(LLVMType, tape_type(tape)) = ", + convert(LLVM.LLVMType, tape_type(tape); allow_boxed = true), + ) end throw(AssertionError(msg2)) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 3df37be117..b672c50f57 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -3,105 +3,139 @@ using ObjectFile using Libdl module FFI - using LLVM - module BLASSupport - # TODO: LAPACK handling - using LinearAlgebra - using ObjectFile - using Libdl - function __init__() - global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) - end - function get_blas_symbols() - symbols = BLAS.get_config().exported_symbols - if BLAS.USE_BLAS64 - return map(n->n*"64_", symbols) - end - return symbols - end - - function lookup_blas_symbol(name) - Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) - end +using LLVM +module BLASSupport +# TODO: LAPACK handling +using LinearAlgebra +using ObjectFile +using Libdl +function __init__() + global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) +end +function get_blas_symbols() + symbols = BLAS.get_config().exported_symbols + if BLAS.USE_BLAS64 + return map(n -> n * "64_", symbols) end + return symbols +end - const ptr_map = Dict{Ptr{Cvoid},String}() - - function __init__() - known_names = ( - "jl_alloc_array_1d", "jl_alloc_array_2d", "jl_alloc_array_3d", - "ijl_alloc_array_1d", "ijl_alloc_array_2d", "ijl_alloc_array_3d", - "jl_new_array", "ijl_new_array", - "jl_array_copy", "ijl_array_copy", - "jl_alloc_string", - "jl_in_threaded_region", "jl_enter_threaded_region", "jl_exit_threaded_region", "jl_set_task_tid", "jl_new_task", - "malloc", "memmove", "memcpy", "memset", - "jl_array_grow_beg", 
"ijl_array_grow_beg", - "jl_array_grow_end", "ijl_array_grow_end", - "jl_array_grow_at", "ijl_array_grow_at", - "jl_array_del_beg", "ijl_array_del_beg", - "jl_array_del_end", "ijl_array_del_end", - "jl_array_del_at", "ijl_array_del_at", - "jl_array_ptr", "ijl_array_ptr", - "jl_value_ptr", "jl_get_ptls_states", "jl_gc_add_finalizer_th", - "jl_symbol_n", "jl_", "jl_object_id", - "jl_reshape_array","ijl_reshape_array", - "jl_matching_methods", "ijl_matching_methods", - "jl_array_sizehint", "ijl_array_sizehint", - "jl_get_keyword_sorter", "ijl_get_keyword_sorter", - "jl_ptr_to_array", - "jl_box_float32", - "ijl_box_float32", - "jl_box_float64", - "ijl_box_float64", - "jl_ptr_to_array_1d", - "jl_eqtable_get", "ijl_eqtable_get", - "memcmp","memchr", - "jl_get_nth_field_checked", "ijl_get_nth_field_checked", - "jl_stored_inline", - "ijl_stored_inline", - "jl_array_isassigned", "ijl_array_isassigned", - "jl_array_ptr_copy", "ijl_array_ptr_copy", - "jl_array_typetagdata", "ijl_array_typetagdata", - "jl_idtable_rehash" - ) - for name in known_names - sym = LLVM.find_symbol(name) - if sym == C_NULL +function lookup_blas_symbol(name) + Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false) +end +end + +const ptr_map = Dict{Ptr{Cvoid},String}() + +function __init__() + known_names = ( + "jl_alloc_array_1d", + "jl_alloc_array_2d", + "jl_alloc_array_3d", + "ijl_alloc_array_1d", + "ijl_alloc_array_2d", + "ijl_alloc_array_3d", + "jl_new_array", + "ijl_new_array", + "jl_array_copy", + "ijl_array_copy", + "jl_alloc_string", + "jl_in_threaded_region", + "jl_enter_threaded_region", + "jl_exit_threaded_region", + "jl_set_task_tid", + "jl_new_task", + "malloc", + "memmove", + "memcpy", + "memset", + "jl_array_grow_beg", + "ijl_array_grow_beg", + "jl_array_grow_end", + "ijl_array_grow_end", + "jl_array_grow_at", + "ijl_array_grow_at", + "jl_array_del_beg", + "ijl_array_del_beg", + "jl_array_del_end", + "ijl_array_del_end", + "jl_array_del_at", + "ijl_array_del_at", + 
"jl_array_ptr", + "ijl_array_ptr", + "jl_value_ptr", + "jl_get_ptls_states", + "jl_gc_add_finalizer_th", + "jl_symbol_n", + "jl_", + "jl_object_id", + "jl_reshape_array", + "ijl_reshape_array", + "jl_matching_methods", + "ijl_matching_methods", + "jl_array_sizehint", + "ijl_array_sizehint", + "jl_get_keyword_sorter", + "ijl_get_keyword_sorter", + "jl_ptr_to_array", + "jl_box_float32", + "ijl_box_float32", + "jl_box_float64", + "ijl_box_float64", + "jl_ptr_to_array_1d", + "jl_eqtable_get", + "ijl_eqtable_get", + "memcmp", + "memchr", + "jl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_stored_inline", + "ijl_stored_inline", + "jl_array_isassigned", + "ijl_array_isassigned", + "jl_array_ptr_copy", + "ijl_array_ptr_copy", + "jl_array_typetagdata", + "ijl_array_typetagdata", + "jl_idtable_rehash", + ) + for name in known_names + sym = LLVM.find_symbol(name) + if sym == C_NULL + continue + end + if haskey(ptr_map, sym) + # On MacOS memcpy and memmove seem to collide? + if name == "memcpy" continue end - if haskey(ptr_map, sym) - # On MacOS memcpy and memmove seem to collide? 
- if name == "memcpy" - continue - end - end - @assert !haskey(ptr_map, sym) - ptr_map[sym] = name end - for sym in BLASSupport.get_blas_symbols() - ptr = BLASSupport.lookup_blas_symbol(sym) - if ptr !== nothing - if haskey(ptr_map, ptr) - if ptr_map[ptr] != sym - @warn "Duplicated symbol in ptr_map" ptr, sym, ptr_map[ptr] - end - continue + @assert !haskey(ptr_map, sym) + ptr_map[sym] = name + end + for sym in BLASSupport.get_blas_symbols() + ptr = BLASSupport.lookup_blas_symbol(sym) + if ptr !== nothing + if haskey(ptr_map, ptr) + if ptr_map[ptr] != sym + @warn "Duplicated symbol in ptr_map" ptr, sym, ptr_map[ptr] end - ptr_map[ptr] = sym + continue end + ptr_map[ptr] = sym end end +end - function memoize!(ptr, fn) - fn = get(ptr_map, ptr, fn) - if !haskey(ptr_map, ptr) - ptr_map[ptr] = fn - else - @assert ptr_map[ptr] == fn - end - return fn +function memoize!(ptr, fn) + fn = get(ptr_map, ptr, fn) + if !haskey(ptr_map, ptr) + ptr_map[ptr] = fn + else + @assert ptr_map[ptr] == fn end + return fn +end end import GPUCompiler: IRError, InvalidIRError @@ -111,7 +145,15 @@ function restore_lookups(mod::LLVM.Module) for (v, k) in FFI.ptr_map if haskey(functions(mod), k) f = functions(mod)[k] - replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstIntToPtr(ConstantInt(T_size_t, convert(UInt, v)), value_type(f)))) + replace_uses!( + f, + LLVM.Value( + LLVM.API.LLVMConstIntToPtr( + ConstantInt(T_size_t, convert(UInt, v)), + value_type(f), + ), + ), + ) eraseInst(mod, f) end end @@ -128,7 +170,7 @@ end # Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA function rewrite_ccalls!(mod::LLVM.Module) for f in collect(functions(mod)) - replaceAndErase = Tuple{Instruction, Instruction}[] + replaceAndErase = Tuple{Instruction,Instruction}[] for bb in blocks(f), inst in instructions(bb) if isa(inst, LLVM.CallInst) fn = called_operand(inst) @@ -160,17 +202,45 @@ function 
rewrite_ccalls!(mod::LLVM.Module) prevname = LLVM.name(inst) LLVM.name!(inst, "") if !isdefined(LLVM, :OperandBundleDef) - newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(operand_bundles(inst)), prevname) - else - newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), prevname) - end - for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(operand_bundles(inst)), + prevname, + ) + else + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), + prevname, + ) + end + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j in 1:count - LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) end Libc.free(Attrs) end @@ -181,26 +251,26 @@ function rewrite_ccalls!(mod::LLVM.Module) continue end if !isdefined(LLVM, :OperandBundleDef) - newbundles = OperandBundle[] - else - newbundles = OperandBundleDef[] - end - for bunduse in operand_bundles(inst) + newbundles = OperandBundle[] + else + 
newbundles = OperandBundleDef[] + end + for bunduse in operand_bundles(inst) if isdefined(LLVM, :OperandBundleDef) - bunduse = LLVM.OperandBundleDef(bunduse) - end + bunduse = LLVM.OperandBundleDef(bunduse) + end if !isdefined(LLVM, :OperandBundleDef) - if LLVM.tag(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - else - if LLVM.tag_name(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - end + if LLVM.tag(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + else + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + end uservals = LLVM.Value[] subchanged = false for lval in LLVM.inputs(bunduse) @@ -228,23 +298,47 @@ function rewrite_ccalls!(mod::LLVM.Module) end changed = true if !isdefined(LLVM, :OperandBundleDef) - push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) + push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) else - push!(newbundles, OperandBundleDef(LLVM.tag_name(bunduse), uservals)) + push!( + newbundles, + OperandBundleDef(LLVM.tag_name(bunduse), uservals), + ) end end changed = false if changed prevname = LLVM.name(inst) LLVM.name!(inst, "") - newinst = call!(B, called_type(inst), called_operand(inst), collect(arguments(inst)), newbundles, prevname) - for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] 
+ newinst = call!( + B, + called_type(inst), + called_operand(inst), + collect(arguments(inst)), + newbundles, + prevname, + ) + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j in 1:count - LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) end Libc.free(Attrs) end @@ -270,7 +364,11 @@ function check_ir!(job, errors, mod::LLVM.Module) prev_ft = eltype(value_type(f)::LLVM.PointerType)::LLVM.FunctionType - mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, parameters(prev_ft))) + mfn = LLVM.API.LLVMAddFunction( + mod, + "malloc", + LLVM.FunctionType(ptr8, parameters(prev_ft)), + ) replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) eraseInst(mod, f) end @@ -291,11 +389,18 @@ function check_ir!(job, errors, imported, f::LLVM.Function) for bb in blocks(f), inst in instructions(bb) if isa(inst, LLVM.CallInst) push!(calls, inst) - # remove illegal invariant.load and jtbaa_const invariants + # remove illegal invariant.load and jtbaa_const invariants elseif isInline && isa(inst, LLVM.LoadInst) md = metadata(inst) if haskey(md, LLVM.MD_tbaa) - modified = LLVM.Metadata(ccall((:EnzymeMakeNonConstTBAA, API.libEnzyme), LLVM.API.LLVMMetadataRef, (LLVM.API.LLVMMetadataRef,), md[LLVM.MD_tbaa])) + modified = LLVM.Metadata( 
+ ccall( + (:EnzymeMakeNonConstTBAA, API.libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef,), + md[LLVM.MD_tbaa], + ), + ) setindex!(md, modified, LLVM.MD_tbaa) end if haskey(md, LLVM.MD_invariant_load) @@ -314,7 +419,18 @@ end const libjulia = Ref{Ptr{Cvoid}}(C_NULL) # List of methods to location of arg which is the mi/function, then start of args -const generic_method_offsets = Dict{String, Tuple{Int,Int}}(("jl_f__apply_latest" => (2,3), "ijl_f__apply_latest" => (2,3), "jl_f__call_latest" => (2,3), "ijl_f__call_latest" => (2,3), "jl_f_invoke" => (2,3), "jl_invoke" => (1,3), "jl_apply_generic" => (1,2), "ijl_f_invoke" => (2,3), "ijl_invoke" => (1,3), "ijl_apply_generic" => (1,2))) +const generic_method_offsets = Dict{String,Tuple{Int,Int}}(( + "jl_f__apply_latest" => (2, 3), + "ijl_f__apply_latest" => (2, 3), + "jl_f__call_latest" => (2, 3), + "ijl_f__call_latest" => (2, 3), + "jl_f_invoke" => (2, 3), + "jl_invoke" => (1, 3), + "jl_apply_generic" => (1, 2), + "ijl_f_invoke" => (2, 3), + "ijl_invoke" => (1, 3), + "ijl_apply_generic" => (1, 2), +)) @inline function has_method(sig, world::UInt, mt::Union{Nothing,Core.MethodTable}) return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing @@ -330,16 +446,17 @@ end @inline function is_inactive(tys, world::UInt, mt) specTypes = Interpreter.simplify_kw(Tuple{tys...}) - if has_method(Tuple{typeof(EnzymeRules.inactive), tys...}, world, mt) + if has_method(Tuple{typeof(EnzymeRules.inactive),tys...}, world, mt) return true end - if has_method(Tuple{typeof(EnzymeRules.inactive_noinl), tys...}, world, mt) + if has_method(Tuple{typeof(EnzymeRules.inactive_noinl),tys...}, world, mt) return true end return false end -import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION +import GPUCompiler: + DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic function 
check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) world = job.world @@ -363,12 +480,28 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) mfn = LLVM.API.LLVMGetNamedFunction(mod, "malloc") if mfn == C_NULL ptr8 = LLVM.PointerType(LLVM.IntType(8)) - mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, [value_type(LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0)))])) + mfn = LLVM.API.LLVMAddFunction( + mod, + "malloc", + LLVM.FunctionType( + ptr8, + [value_type(LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0)))], + ), + ) end mfn2 = LLVM.Function(mfn) - nval = ptrtoint!(b, call!(b, LLVM.function_type(mfn2), mfn2, [LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0))]), value_type(inst)) + nval = ptrtoint!( + b, + call!( + b, + LLVM.function_type(mfn2), + mfn2, + [LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0))], + ), + value_type(inst), + ) replace_uses!(inst, nval) - LLVM.API.LLVMInstructionEraseFromParent(inst) + LLVM.API.LLVMInstructionEraseFromParent(inst) elseif fn == "jl_load_and_lookup" || fn == "ijl_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) @@ -376,7 +509,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) arg1 = operands(inst)[1] while isa(arg1, ConstantExpr) - if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || opcode(arg1) == LLVM.API.LLVMBitCast || opcode(arg1) == LLVM.API.LLVMIntToPtr + if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || + opcode(arg1) == LLVM.API.LLVMBitCast || + opcode(arg1) == LLVM.API.LLVMIntToPtr arg1 = operands(arg1)[1] else break @@ -389,71 +524,106 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) hnd = operands(inst)[3] if isa(hnd, LLVM.GlobalVariable) hnd = LLVM.name(hnd) - if fn == "jl_lazy_load_and_lookup" - res = ccall(:jl_load_and_lookup, Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) - else - res = ccall(:ijl_load_and_lookup, 
Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) - end - replaceWith = LLVM.ConstantInt(LLVM.IntType(8*sizeof(Int)), reinterpret(UInt, res)) - for u in LLVM.uses(inst) - st = LLVM.user(u) - if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst - ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) - for u in LLVM.uses(ptr) - ld = LLVM.user(u) - if isa(ld, LLVM.LoadInst) - b = IRBuilder() - position!(b, ld) - for u in LLVM.uses(ld) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) + if fn == "jl_lazy_load_and_lookup" + res = ccall( + :jl_load_and_lookup, + Ptr{Cvoid}, + (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), + arg1, + fname, + reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + ) + else + res = ccall( + :ijl_load_and_lookup, + Ptr{Cvoid}, + (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), + arg1, + fname, + reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + ) + end + replaceWith = LLVM.ConstantInt( + LLVM.IntType(8 * sizeof(Int)), + reinterpret(UInt, res), + ) + for u in LLVM.uses(inst) + st = LLVM.user(u) + if isa(st, LLVM.StoreInst) && + LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) + for u in LLVM.uses(ptr) + ld = LLVM.user(u) + if isa(ld, LLVM.LoadInst) + b = IRBuilder() + position!(b, ld) + for u in LLVM.uses(ld) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end end + replace_uses!( + ld, + LLVM.inttoptr!( + b, + replaceWith, + value_type(inst), + ), + ) end - replace_uses!(ld, LLVM.inttoptr!(b, replaceWith, value_type(inst))) end end end - end - b = IRBuilder() - position!(b, inst) - replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) - for u in LLVM.uses(inst) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, LLVM.PHIInst) - if all(x->first(x) == inst || first(x) == replacement, LLVM.incoming(u)) + b = IRBuilder() + position!(b, 
inst) + replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.PHIInst) + if all( + x -> first(x) == inst || first(x) == replacement, + LLVM.incoming(u), + ) - for u in LLVM.uses(u) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, LLVM.BitCastInst) - for u1 in LLVM.uses(u) - u1 = LLVM.user(u1) - if isa(u1, LLVM.CallInst) - push!(calls, u1) - end - end - replace_uses!(u, LLVM.inttoptr!(b, replaceWith, value_type(u))) + for u in LLVM.uses(u) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.BitCastInst) + for u1 in LLVM.uses(u) + u1 = LLVM.user(u1) + if isa(u1, LLVM.CallInst) + push!(calls, u1) end end + replace_uses!( + u, + LLVM.inttoptr!( + b, + replaceWith, + value_type(u), + ), + ) end end end - replace_uses!(inst, replacement) - LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end + replace_uses!(inst, replacement) + LLVM.API.LLVMInstructionEraseFromParent(inst) end end end - + elseif fn == "jl_lazy_load_and_lookup" || fn == "ijl_lazy_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) @@ -469,7 +639,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) op = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(op, 0)) end if isa(op, ConstantInt) - rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, op)+8) + rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, op) + 8) ld = unsafe_load(convert(Ptr{Ptr{Cvoid}}, rep)) flib = Base.unsafe_pointer_to_objref(ld) end @@ -485,8 +655,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if isa(fname, LLVM.GlobalVariable) fname = LLVM.initializer(fname) end - if (isa(fname, LLVM.ConstantArray) || isa(fname, LLVM.ConstantDataArray)) && eltype(value_type(fname)) == LLVM.IntType(8) - fname = String(map((x)->convert(UInt8, x), collect(fname)[1:(end-1)])) + if (isa(fname, LLVM.ConstantArray) 
|| isa(fname, LLVM.ConstantDataArray)) && + eltype(value_type(fname)) == LLVM.IntType(8) + fname = String(map((x) -> convert(UInt8, x), collect(fname)[1:(end-1)])) end if !isa(fname, String) || !isa(flib, String) @@ -494,7 +665,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end found = false - + try data = open(flib, "r") do io lib = readmeta(io) @@ -537,14 +708,18 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) for u in LLVM.uses(inst) st = LLVM.user(u) - if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + if isa(st, LLVM.StoreInst) && + LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) for u in LLVM.uses(ptr) ld = LLVM.user(u) if isa(ld, LLVM.LoadInst) b = IRBuilder() position!(b, ld) - replace_uses!(ld, LLVM.pointercast!(b, replaceWith, value_type(inst))) + replace_uses!( + ld, + LLVM.pointercast!(b, replaceWith, value_type(inst)), + ) end end end @@ -558,14 +733,28 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) else if fn == "jl_lazy_load_and_lookup" - res = ccall(:jl_lazy_load_and_lookup, Ptr{Cvoid}, (Any, Cstring), flib, fname) + res = ccall( + :jl_lazy_load_and_lookup, + Ptr{Cvoid}, + (Any, Cstring), + flib, + fname, + ) else - res = ccall(:ijl_lazy_load_and_lookup, Ptr{Cvoid}, (Any, Cstring), flib, fname) + res = ccall( + :ijl_lazy_load_and_lookup, + Ptr{Cvoid}, + (Any, Cstring), + flib, + fname, + ) end - replaceWith = LLVM.ConstantInt(LLVM.IntType(8*sizeof(Int)), reinterpret(UInt, res)) + replaceWith = + LLVM.ConstantInt(LLVM.IntType(8 * sizeof(Int)), reinterpret(UInt, res)) for u in LLVM.uses(inst) st = LLVM.user(u) - if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + if isa(st, LLVM.StoreInst) && + LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) for u in LLVM.uses(ptr) ld = 
LLVM.user(u) @@ -578,7 +767,10 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) push!(calls, u) end end - replace_uses!(ld, LLVM.inttoptr!(b, replaceWith, value_type(inst))) + replace_uses!( + ld, + LLVM.inttoptr!(b, replaceWith, value_type(inst)), + ) end end end @@ -587,32 +779,38 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) b = IRBuilder() position!(b, inst) replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) - for u in LLVM.uses(inst) + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.PHIInst) + if all( + x -> first(x) == inst || first(x) == replacement, + LLVM.incoming(u), + ) + + for u in LLVM.uses(u) u = LLVM.user(u) if isa(u, LLVM.CallInst) push!(calls, u) end - if isa(u, LLVM.PHIInst) - if all(x->first(x) == inst || first(x) == replacement, LLVM.incoming(u)) - - for u in LLVM.uses(u) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, LLVM.BitCastInst) - for u1 in LLVM.uses(u) - u1 = LLVM.user(u1) - if isa(u1, LLVM.CallInst) - push!(calls, u1) - end - end - replace_uses!(u, LLVM.inttoptr!(b, replaceWith, value_type(u))) - end + if isa(u, LLVM.BitCastInst) + for u1 in LLVM.uses(u) + u1 = LLVM.user(u1) + if isa(u1, LLVM.CallInst) + push!(calls, u1) end end + replace_uses!( + u, + LLVM.inttoptr!(b, replaceWith, value_type(u)), + ) end end + end + end + end replace_uses!(inst, replacement) LLVM.API.LLVMInstructionEraseFromParent(inst) end @@ -622,29 +820,77 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if isa(dest, LLVM.Function) && LLVM.name(dest) == "jl_f__apply_iterate" # Add 1 to account for function being first arg iteroff = 2 - + legal, iterlib = absint(operands(inst)[iteroff+1]) if legal && iterlib == Base.iterate - legal, GT = abs_typeof(operands(inst)[4+1], true) + legal, GT, byref = abs_typeof(operands(inst)[4+1], true) funcoff = 3 - legal2, funclib = 
abs_typeof(operands(inst)[funcoff+1]) + legal2, funclib, byref2 = abs_typeof(operands(inst)[funcoff+1]) if legal && (GT <: Vector || GT <: Tuple) if legal2 tys = [funclib, Vararg{Any}] - if funclib == typeof(Core.apply_type) || is_inactive(tys, world, method_table) + if funclib == typeof(Core.apply_type) || + is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) - elseif funclib == typeof(Base.tuple) && length(operands(inst)) == 4+1+1 && Base.isconcretetype(GT) && Enzyme.Compiler.guaranteed_const_nongen(GT, world) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) + elseif funclib == typeof(Base.tuple) && + length(operands(inst)) == 4 + 1 + 1 && + Base.isconcretetype(GT) && + Enzyme.Compiler.guaranteed_const_nongen(GT, world) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( 
+ inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -654,11 +900,11 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if isa(dest, LLVM.Function) && in(LLVM.name(dest), keys(generic_method_offsets)) offset, start = generic_method_offsets[LLVM.name(dest)] # Add 1 to account for function being first arg - legal, flibty = abs_typeof(operands(inst)[offset+1]) + legal, flibty, byref = abs_typeof(operands(inst)[offset+1]) if legal tys = Type[flibty] for op in collect(operands(inst))[start+1:end-1] - legal, typ = abs_typeof(op, true) + legal, typ, byref2 = abs_typeof(op, true) if !legal typ = Any end @@ -673,11 +919,33 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end if is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - 
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -697,7 +965,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) ptr = Ptr{Cvoid}(ptr_val) # look it up in the Julia JIT cache - frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint,), ptr, 0) + frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint), ptr, 0) if length(frames) >= 1 fn, file, line, linfo, fromC, inlined = last(frames) @@ -709,11 +977,24 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) lfn = LLVM.API.LLVMGetNamedFunction(mod, fn) if lfn == C_NULL - lfn = LLVM.API.LLVMAddFunction(mod, fn, LLVM.API.LLVMGetCalledFunctionType(inst)) + lfn = LLVM.API.LLVMAddFunction( + mod, + fn, + LLVM.API.LLVMGetCalledFunctionType(inst), + ) else - lfn = LLVM.API.LLVMConstBitCast(lfn, LLVM.PointerType(LLVM.FunctionType(LLVM.API.LLVMGetCalledFunctionType(inst)))) + lfn = LLVM.API.LLVMConstBitCast( + lfn, + LLVM.PointerType( + LLVM.FunctionType(LLVM.API.LLVMGetCalledFunctionType(inst)), + ), + ) end - LLVM.API.LLVMSetOperand(inst, LLVM.API.LLVMGetNumOperands(inst)-1, lfn) + LLVM.API.LLVMSetOperand( + inst, + LLVM.API.LLVMGetNumOperands(inst) - 1, + lfn, + ) end end end @@ -721,11 +1002,11 @@ function check_ir!(job, errors, 
imported, inst::LLVM.CallInst, calls) if isa(dest, LLVM.Function) && in(LLVM.name(dest), keys(generic_method_offsets)) offset, start = generic_method_offsets[LLVM.name(dest)] - legal, flibty = abs_typeof(operands(inst)[offset]) + legal, flibty, byref = abs_typeof(operands(inst)[offset]) if legal tys = Type[flibty] for op in collect(operands(inst))[start:end-1] - legal, typ = abs_typeof(op, true) + legal, typ, byref2 = abs_typeof(op, true) if !legal typ = Any end @@ -735,15 +1016,18 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if legal && isa(flib, Core.MethodInstance) if !Base.isvarargtype(flib.specTypes.parameters[end]) if length(tys) != length(flib.specTypes.parameters) - msg = sprint() do io::IO - println(io, "Enzyme internal error (length(tys) != length(flib.specTypes.parameters))") - println(io, "tys=", tys) - println(io, "flib=", flib) - println(io, "inst=", inst) - println(io, "offset=", offset) - println(io, "start=", start) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println( + io, + "Enzyme internal error (length(tys) != length(flib.specTypes.parameters))", + ) + println(io, "tys=", tys) + println(io, "flib=", flib) + println(io, "inst=", inst) + println(io, "offset=", offset) + println(io, "start=", start) + end + throw(AssertionError(msg)) end end tys = flib.specTypes.parameters @@ -752,11 +1036,33 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, 
LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -767,15 +1073,15 @@ end function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width) - todo = Tuple{LLVM.Value, Tuple}[] + todo = Tuple{LLVM.Value,Tuple}[] for b in blocks(enzymefn) term = terminator(b) if LLVM.API.LLVMIsAReturnInst(term) != C_NULL if width == 1 push!(todo, (operands(term)[1], off == -1 ? () : (off,))) else - for i in 1:width - push!(todo, (operands(term)[1], off == -1 ? (i,) : (off,i))) + for i = 1:width + push!(todo, (operands(term)[1], off == -1 ? 
(i,) : (off, i))) end end end @@ -803,7 +1109,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width if isa(cur, LLVM.ExtractValueInst) noff = off - for i in 1:LLVM.API.LLVMGetNumIndices(cur) + for i = 1:LLVM.API.LLVMGetNumIndices(cur) noff = (noff..., convert(Int, unsafe_load(LLVM.API.LLVMGetIndices(cur), i))) end push!(todo, (operands(cur)[1], noff)) @@ -819,7 +1125,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width # if inserting at the current desired offset, we have found the value we need if ind == off[1] push!(todo, (operands(cur)[2], off[2:end])) - # otherwise it must be inserted at a different point + # otherwise it must be inserted at a different point else push!(todo, (operands(cur)[1], off)) end @@ -833,15 +1139,18 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width nm = LLVM.name(fn) end - # Type tag is arg 3 if nm == "julia.gc_alloc_obj" - legal, Ty = abs_typeof(cur) + legal, Ty, byref = abs_typeof(cur) @assert legal reg = active_reg_inner(Ty, (), world) if reg == ActiveState || reg == MixedState NTy = Base.RefValue{Ty} @assert sizeof(Ty) == sizeof(NTy) - LLVM.API.LLVMSetOperand(cur, 2, unsafe_to_llvm(LLVM.IRBuilder(cur), NTy)) + LLVM.API.LLVMSetOperand( + cur, + 2, + unsafe_to_llvm(LLVM.IRBuilder(cur), NTy), + ) end continue end @@ -858,7 +1167,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width if isa(cur, LLVM.LoadInst) al = operands(cur)[1] if isa(al, LLVM.AllocaInst) - atodo = Tuple{LLVM.Value, Tuple, LLVM.Value}[] + atodo = Tuple{LLVM.Value,Tuple,LLVM.Value}[] for u in LLVM.uses(al) push!(atodo, (LLVM.user(u), off, al)) end @@ -893,22 +1202,23 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width continue end - msg = sprint() do io::IO - println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[1])") - println(io, string(enzymefn)) - println(io, "BAD") - println(io, "acur=", 
acur) - println(io, "aoff=", aoff) - println(io, "prev=", prev) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[1])") + println(io, string(enzymefn)) + println(io, "BAD") + println(io, "acur=", acur) + println(io, "aoff=", aoff) + println(io, "prev=", prev) + end + throw(AssertionError(msg)) end continue end end - if length(off) == 0 && value_type(cur) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) - legal, typ = abs_typeof(cur) + if length(off) == 0 && + value_type(cur) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) + legal, typ, byref = abs_typeof(cur) if legal reg = active_reg_inner(typ, (), world) if !(reg == ActiveState || reg == MixedState) @@ -921,7 +1231,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width push!(todo, (cur[off[1]], off[2:end])) continue end - + if isa(cur, LLVM.CallInst) dest = called_operand(cur) if isa(dest, LLVM.Function) @@ -932,12 +1242,12 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end - msg = sprint() do io::IO - println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") - println(io, string(enzymefn)) - println(io, "cur=", string(cur)) - println(io, "off=", off) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") + println(io, string(enzymefn)) + println(io, "cur=", string(cur)) + println(io, "off=", off) + end + throw(AssertionError(msg)) end end diff --git a/src/gradientutils.jl b/src/gradientutils.jl index ff83f60c1d..b0a0bff26b 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -6,14 +6,35 @@ end Base.unsafe_convert(::Type{API.EnzymeGradientUtilsRef}, gutils::GradientUtils) = gutils.ref LLVM.dispose(gutils::GradientUtils) = throw("Cannot free gutils") -function call_samefunc_with_inverted_bundles!(B::LLVM.IRBuilder, gutils::GradientUtils, 
orig::LLVM.CallInst, args::Vector{<:LLVM.Value}, valTys::Vector{API.CValueType}, lookup::Bool) +function call_samefunc_with_inverted_bundles!( + B::LLVM.IRBuilder, + gutils::GradientUtils, + orig::LLVM.CallInst, + args::Vector{<:LLVM.Value}, + valTys::Vector{API.CValueType}, + lookup::Bool, +) @assert length(args) == length(valTys) - return LLVM.Value(API.EnzymeGradientUtilsCallWithInvertedBundles(gutils, LLVM.called_operand(orig), LLVM.called_type(orig), args, length(args), orig, valTys, length(valTys), B, #=lookup=#false)) + return LLVM.Value( + API.EnzymeGradientUtilsCallWithInvertedBundles( + gutils, + LLVM.called_operand(orig), + LLVM.called_type(orig), + args, + length(args), + orig, + valTys, + length(valTys), + B, + false, + ), + ) #=lookup=# end get_width(gutils::GradientUtils) = API.EnzymeGradientUtilsGetWidth(gutils) get_mode(gutils::GradientUtils) = API.EnzymeGradientUtilsGetMode(gutils) -get_runtime_activity(gutils::GradientUtils) = API.EnzymeGradientUtilsGetRuntimeActivity(gutils) +get_runtime_activity(gutils::GradientUtils) = + API.EnzymeGradientUtilsGetRuntimeActivity(gutils) function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) w = get_width(gutils) @@ -23,26 +44,45 @@ function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) return LLVM.ArrayType(T, Int(w)) end end -function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) - uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) - if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1 +function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) + uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) + if API.EnzymeGradientUtilsGetUncacheableArgs( + gutils, + orig, + uncacheable, + length(uncacheable), + ) != 1 uncacheable .= 1 end return uncacheable end -erase_with_placeholder(gutils::GradientUtils, inst::LLVM.Instruction, orig::LLVM.Instruction, erase::Bool=true) = 
API.EnzymeGradientUtilsEraseWithPlaceholder(gutils, inst, orig, erase) -is_constant_value(gutils::GradientUtils, val::LLVM.Value) = API.EnzymeGradientUtilsIsConstantValue(gutils, val) != 0 +erase_with_placeholder( + gutils::GradientUtils, + inst::LLVM.Instruction, + orig::LLVM.Instruction, + erase::Bool = true, +) = API.EnzymeGradientUtilsEraseWithPlaceholder(gutils, inst, orig, erase) +is_constant_value(gutils::GradientUtils, val::LLVM.Value) = + API.EnzymeGradientUtilsIsConstantValue(gutils, val) != 0 -is_constant_inst(gutils::GradientUtils, inst::LLVM.Instruction) = API.EnzymeGradientUtilsIsConstantInstruction(gutils, inst) != 0 +is_constant_inst(gutils::GradientUtils, inst::LLVM.Instruction) = + API.EnzymeGradientUtilsIsConstantInstruction(gutils, inst) != 0 -new_from_original(gutils::GradientUtils, val::LLVM.Value) = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, val)) +new_from_original(gutils::GradientUtils, val::LLVM.Value) = + LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, val)) -lookup_value(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = LLVM.Value(API.EnzymeGradientUtilsLookup(gutils, val, B)) +lookup_value(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = + LLVM.Value(API.EnzymeGradientUtilsLookup(gutils, val, B)) -invert_pointer(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, val, B)) +invert_pointer(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = + LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, val, B)) -function debug_from_orig!(gutils::GradientUtils, nval::LLVM.Instruction, oval::LLVM.Instruction) +function debug_from_orig!( + gutils::GradientUtils, + nval::LLVM.Instruction, + oval::LLVM.Instruction, +) API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, nval, oval) nothing end diff --git a/src/internal_rules.jl b/src/internal_rules.jl index f29ed0d977..f8c6e730bb 100644 --- a/src/internal_rules.jl +++ 
b/src/internal_rules.jl @@ -66,7 +66,12 @@ end function EnzymeRules.inactive(::typeof(Core.kwfunc), args...) return nothing end -function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, ::Random.Sampler, ::AbstractArray) +function EnzymeRules.inactive( + ::typeof(Random.rand!), + ::Random.AbstractRNG, + ::Random.Sampler, + ::AbstractArray, +) return nothing end function EnzymeRules.inactive(::typeof(Random.randn!), args...) @@ -96,7 +101,12 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end -function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} +function EnzymeRules.inactive_noinl( + ::typeof(Base.setindex!), + ::IdDict{K,V}, + ::K, + ::V, +) where {K,V<:Integer} return nothing end @@ -123,24 +133,49 @@ Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) = no # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.deepcopy)}, + ::Type{<:DuplicatedNoNeed}, + x::Duplicated, +) return deepcopy(x.dval) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.deepcopy)}, + ::Type{<:BatchDuplicatedNoNeed}, + x::BatchDuplicated{T,N}, +) where {T,N} ntuple(Val(N)) do _ deepcopy(x.dval) end end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} +@inline function deepcopy_rtact( + copied::RT, + primal::RT, + seen::IdDict, + 
shadow::RT, +) where {RT<:Union{Integer,Char}} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} +@inline function deepcopy_rtact( + copied::RT, + primal::RT, + seen::IdDict, + shadow::RT, +) where {RT<:AbstractFloat} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Array} +@inline function deepcopy_rtact( + copied::RT, + primal::RT, + seen::IdDict, + shadow::RT, +) where {RT<:Array} if !haskey(seen, shadow) if primal === shadow return seen[shadow] = copied @@ -154,19 +189,34 @@ end return seen[shadow] end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{<:Duplicated}, + x::Duplicated, +) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{<:BatchDuplicated}, + x::BatchDuplicated{T,N}, +) where {T,N} primal = func.val(x.val) return BatchDuplicated(primal, ntuple(Val(N)) do i deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) end) end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{RT}, + x::Annotation{Ty}, +) where {RT,Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -183,14 +233,16 @@ function 
EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const shadow = if EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + Enzyme.make_zero( + source, + Val(!EnzymeRules.needs_primal(config)), #=copy_if_inactive=# ) else ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + Enzyme.make_zero( + source, + Val(!EnzymeRules.needs_primal(config)), #=copy_if_inactive=# ) end end @@ -202,7 +254,11 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const end -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} +@inline function accumulate_into( + into::RT, + seen::IdDict, + from::RT, +)::Tuple{RT,RT} where {RT<:Array} if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end @@ -217,9 +273,13 @@ end return seen[into] end -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} +@inline function accumulate_into( + into::RT, + seen::IdDict, + from::RT, +)::Tuple{RT,RT} where {RT<:AbstractFloat} if !haskey(seen, into) - seen[into] = (into+from, RT(0)) + seen[into] = (into + from, RT(0)) end return seen[into] end @@ -234,12 +294,18 @@ end return seen[into] end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{RT}, + shadow, + x::Annotation{Ty}, +) where {RT,Ty} if EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) else - for i in 1:EnzymeRules.width(config) + for i = 1:EnzymeRules.width(config) accumulate_into(x.dval[i], IdDict(), shadow[i]) end end @@ -248,43 +314,100 
@@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(B return (nothing,) end -@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_fwd( + idx, + tapes::Vector, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} @inbounds tapes[idx] = thunk(f, Const(idx), fargs...)[1] end -@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_fwd( + idx, + tapes::Ptr, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.runtime_activity(config), + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + false, + }() + fwd_thunk, rev_thunk = + autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) 
TapeType = EnzymeRules.tape_type(fwd_thunk) tapes = if Enzyme.Compiler.any_jltypes(TapeType) Vector{TapeType}(undef, count.val) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*count.val)) + Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType) * count.val)) end Enzyme.pmap(pmap_fwd, count.val, tapes, fwd_thunk, body, args...) return EnzymeRules.AugmentedReturn(nothing, nothing, tapes) end -@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_rev( + idx, + tapes::Vector, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) end -@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_rev( + idx, + tapes::Ptr, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) 
+function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + tapes, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.runtime_activity(config), + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + false, + }() + fwd_thunk, rev_thunk = + autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -294,7 +417,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(E Libc.free(tapes) end - return ntuple(Val(2+length(args))) do _ + return ntuple(Val(2 + length(args))) do _ Base.@_inline_meta nothing end @@ -303,7 +426,7 @@ end # From LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1110 -@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT, BT} +@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT,BT} LinearAlgebra.require_one_based_indexing(cache_A, b) m, n = size(cache_A) @@ -323,12 +446,18 @@ end return LinearAlgebra.qr(cache_A, ColumnNorm()) end -@inline onedimensionalize(::Type{T}) where T <: Array = Vector{eltype(T)} +@inline onedimensionalize(::Type{T}) where {T<:Array} = Vector{eltype(T)} # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(\)}, + ::Type{RT}, + A::Annotation{AT}, + b::Annotation{BT}, +) where {RT,AT<:Array,BT<:Array} cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) @@ -368,30 +497,46 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, 
func::Const end UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT), AT, onedimensionalize(BT), Vector{Int}} + LinearAlgebra.Diagonal{eltype(AT),onedimensionalize(BT)}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT,Vector{Int}}, + LinearAlgebra.QRPivoted{eltype(AT),AT,onedimensionalize(BT),Vector{Int}}, } - cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{ - eltype(RT), - EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing, - UT, - typeof(cache_b) - }}( - (cache_res, dres, cache_A, cache_b) - ) + cache = NamedTuple{ + (Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), + Tuple{ + eltype(RT), + EnzymeRules.needs_shadow(config) ? + ( + EnzymeRules.width(config) == 1 ? 
eltype(RT) : + NTuple{EnzymeRules.width(config),eltype(RT)} + ) : Nothing, + UT, + typeof(cache_b), + }, + }((cache_res, dres, cache_A, cache_b)) return EnzymeRules.AugmentedReturn{ EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), - typeof(cache) - }(retres, dres, cache) + typeof(cache), + }( + retres, + dres, + cache, + ) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(\)}, + ::Type{RT}, + cache, + A::Annotation{<:Array}, + b::Annotation{<:Array}, +) where {RT} y, dys, cache_A, cache_b = cache @@ -448,14 +593,14 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(\ dy .= eltype(dy)(0) end - return (nothing,nothing) + return (nothing, nothing) end const EnzymeTriangulars = Union{ UpperTriangular{<:Complex}, LowerTriangular{<:Complex}, UnitUpperTriangular{<:Complex}, - UnitLowerTriangular{<:Complex} + UnitLowerTriangular{<:Complex}, } function EnzymeRules.augmented_primal( @@ -464,8 +609,8 @@ function EnzymeRules.augmented_primal( ::Type{RT}, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {RT,YT<:Array,AT<:EnzymeTriangulars,BT<:Array} cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val cache_A = EnzymeRules.overwritten(config)[2] ? 
copy(A.val) : A.val cache_A = compute_lu_cache(cache_A, B.val) @@ -476,9 +621,11 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn{ EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), - Tuple{typeof(cache_Y), typeof(cache_A), typeof(cache_B)} + Tuple{typeof(cache_Y),typeof(cache_A),typeof(cache_B)}, }( - primal, shadow, (cache_Y, cache_A, cache_B) + primal, + shadow, + (cache_Y, cache_A, cache_B), ) end @@ -489,11 +636,11 @@ function EnzymeRules.reverse( cache, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {YT<:Array,RT,AT<:EnzymeTriangulars,BT<:Array} if !isa(Y, Const) (cache_Yout, cache_A, cache_B) = cache - for b in 1:EnzymeRules.width(config) + for b = 1:EnzymeRules.width(config) dY = EnzymeRules.width(config) == 1 ? Y.dval : Y.dval[b] z = adjoint(cache_A) \ dY if !isa(B, Const) @@ -516,7 +663,13 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + out::Annotation{AT}, + inp::Annotation{BT}, +) where {RT,AT<:Array,BT<:Tuple} primal = if EnzymeRules.needs_primal(config) out.val else @@ -531,9 +684,16 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const return EnzymeRules.AugmentedReturn(primal, shadow, nothing) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, 
inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - nr, nc = size(out.val,1), size(out.val,2) - for b in 1:EnzymeRules.width(config) +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + _, + out::Annotation{AT}, + inp::Annotation{BT}, +) where {RT,AT<:Array,BT<:Tuple} + nr, nc = size(out.val, 1), size(out.val, 2) + for b = 1:EnzymeRules.width(config) da = if EnzymeRules.width(config) == 1 out.dval else @@ -547,7 +707,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(B res = da[i, j] da[i, j] = 0 j += 1 - if j == nc+1 + if j == nc + 1 i += 1 j = 1 end @@ -558,18 +718,19 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(B T(0) end end - return (nothing, dinp)::Tuple{Nothing, BT} + return (nothing, dinp)::Tuple{Nothing,BT} end end return (nothing, nothing) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -584,15 +745,16 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}; - kwargs... 
- ) where {T <: AbstractArray{<:AbstractFloat}, N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat},N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] - for i in 1:N + for i = 1:N xs.dval[i] .= xs.dval[i][inds] end if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) @@ -608,12 +770,12 @@ end function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -631,26 +793,27 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - tape, - xs::Duplicated{T}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] return (nothing,) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... 
- ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -672,18 +835,19 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}, N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat},N} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) xs.val .= xs.val[inds] - for i in 1:N + for i = 1:N xs.dval[i] .= xs.dval[i][inds] end @@ -707,13 +871,13 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... 
- ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,Active,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -733,14 +897,14 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(partialsort!)}, - dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, - tape, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(partialsort!)}, + dret::Union{Active,Type{<:Union{Const,Active,DuplicatedNoNeed,Duplicated}}}, + tape, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = tape kv = k.val if dret isa Active @@ -760,11 +924,14 @@ end # -> # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, - fact::Annotation{<:Cholesky}, - B::Annotation{<:AbstractVecOrMat}; - kwargs...) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, + fact::Annotation{<:Cholesky}, + B::Annotation{<:AbstractVecOrMat}; + kwargs..., +) if B isa Const retval = func.val(fact.val, B.val; kwargs...) 
if EnzymeRules.needs_primal(config) @@ -827,10 +994,16 @@ end # Float64 ranges in Julia use bitwise `&` with higher precision # to correct for numerical error, thus we put rules over the # operations as this is not directly differentiable -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, - BatchDuplicated,BatchDuplicatedNoNeed}}, - start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{Colon}, + RT::Type{ + <:Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}, + }, + start::Annotation{<:AbstractFloat}, + step::Annotation{<:AbstractFloat}, + stop::Annotation{<:AbstractFloat}, +) ret = func.val(start.val, step.val, stop.val) dstart = if start isa Const zero(eltype(ret)) @@ -839,7 +1012,9 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed ntuple(i -> start.dval[i], Val(EnzymeRules.width(config))) else - error("Annotation type $(typeof(start)) not supported for range start. Please open an issue") + error( + "Annotation type $(typeof(start)) not supported for range start. Please open an issue", + ) end dstep = if step isa Const @@ -849,25 +1024,39 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed ntuple(i -> step.dval[i], Val(EnzymeRules.width(config))) else - error("Annotation type $(typeof(start)) not supported for range step. Please open an issue") + error( + "Annotation type $(typeof(start)) not supported for range step. 
Please open an issue", + ) end if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 - return Duplicated(ret, range(dstart; step=dstep, length=length(ret))) + return Duplicated(ret, range(dstart; step = dstep, length = length(ret))) else - return BatchDuplicated(ret, - ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; - step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(EnzymeRules.width(config)))) + return BatchDuplicated( + ret, + ntuple( + i -> range( + dstart isa Number ? dstart : dstart[i]; + step = dstep isa Number ? dstep : dstep[i], + length = length(ret), + ), + Val(EnzymeRules.width(config)), + ), + ) end elseif EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 - return range(dstart; step=dstep, length=length(ret)) + return range(dstart; step = dstep, length = length(ret)) else - return ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; - step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(EnzymeRules.width(config))) + return ntuple( + i -> range( + dstart isa Number ? dstart : dstart[i]; + step = dstep isa Number ? 
dstep : dstep[i], + length = length(ret), + ), + Val(EnzymeRules.width(config)), + ) end elseif EnzymeRules.needs_primal(config) return ret @@ -878,8 +1067,14 @@ end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{Colon}, ::Type{<:Active}, - start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{Colon}, + ::Type{<:Active}, + start::Annotation{<:AbstractFloat}, + step::Annotation{<:AbstractFloat}, + stop::Annotation{<:AbstractFloat}, +) if EnzymeRules.needs_primal(config) primal = func.val(start.val, step.val, stop.val) @@ -889,8 +1084,15 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{Colon}, dret, tape::Nothing, - start::Annotation{T1}, step::Annotation{T2}, stop::Annotation{T3}) where {T1<:AbstractFloat, T2<:AbstractFloat, T3<:AbstractFloat} +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{Colon}, + dret, + tape::Nothing, + start::Annotation{T1}, + step::Annotation{T2}, + stop::Annotation{T3}, +) where {T1<:AbstractFloat,T2<:AbstractFloat,T3<:AbstractFloat} dstart = if start isa Const nothing @@ -929,11 +1131,12 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{Colon}, end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - Ty::Const{Type{BigFloat}}, - RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; - kwargs... 
- ) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}}; + kwargs..., +) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 @@ -950,9 +1153,9 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, return Ty.val(; kwargs...) else return ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - Ty.val(; kwargs...) - end + Base.@_inline_meta + Ty.val(; kwargs...) + end end elseif EnzymeRules.needs_primal(config) return Ty.val(; kwargs...) @@ -962,11 +1165,11 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfig, - Ty::Const{Type{BigFloat}}, - RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, - kwargs... - ) + config::EnzymeRules.RevConfig, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}}, + kwargs..., +) primal = if EnzymeRules.needs_primal(config) Ty.val(; kwargs...) 
else @@ -988,22 +1191,23 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.RevConfig, - Ty::Const{Type{BigFloat}}, - RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, - tape, - kwargs..., - ) + config::EnzymeRules.RevConfig, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}}, + tape, + kwargs..., +) return () end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - Ty::Const{typeof(Random.rand!)}, - RT::Type, - rng::Annotation{rngty}, - dst::Annotation{<:Array{FT}}, - smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, - ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, +) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} Ty.val(rng.val, dst.val, smpl.val) if !(dst isa Const) @@ -1017,7 +1221,7 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end end - + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) dst elseif EnzymeRules.needs_shadow(config) @@ -1029,13 +1233,14 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, - Ty::Const{typeof(Random.rand!)}, - RT::Type, - rng::Annotation{rngty}, - dst::Annotation{<:Array{FT}}, - smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, - ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + 
smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, +) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} Ty.val(rng.val, dst.val, smpl.val) if RT <: Duplicated || RT <: DuplicatedNoNeed fill!(dst.dval, 0) @@ -1047,16 +1252,21 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, nothing end end - return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? dst.val : nothing, EnzymeRules.needs_shadow(config) ? dst.dval : nothing, nothing) + return EnzymeRules.AugmentedReturn( + EnzymeRules.needs_primal(config) ? dst.val : nothing, + EnzymeRules.needs_shadow(config) ? dst.dval : nothing, + nothing, + ) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, - Ty::Const{typeof(Random.rand!)}, - RT::Type, - tape, - rng::Annotation{rngty}, - dst::Annotation{<:Array{FT}}, - smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, - ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + tape, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, +) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} return (nothing, nothing, nothing) end diff --git a/src/pmap.jl b/src/pmap.jl index f5160e0b62..a46d1c777a 100644 --- a/src/pmap.jl +++ b/src/pmap.jl @@ -5,17 +5,17 @@ function pmap(body::Body, count, args::Vararg{Any,N}) where {Body,N} tasks = Vector{Task}(undef, n_gen) cnt = (count + n_gen - 1) ÷ n_gen for i = 0:(n_gen-1) - let start = i * cnt, endv = min(count, (i+1) * cnt)-1 - t = Task() do - for j in start:endv - body(j+1, args...) 
- end - nothing - end - t.sticky = true - ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i) - @inbounds tasks[i+1] = t - schedule(t) + let start = i * cnt, endv = min(count, (i + 1) * cnt) - 1 + t = Task() do + for j = start:endv + body(j + 1, args...) + end + nothing + end + t.sticky = true + ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i) + @inbounds tasks[i+1] = t + schedule(t) end end try @@ -28,27 +28,27 @@ function pmap(body::Body, count, args::Vararg{Any,N}) where {Body,N} end macro parallel(args...) - captured = args[1:end-1] - ex = args[end] - if !(isa(ex, Expr) && ex.head === :for) - throw(ArgumentError("@parallel requires a `for` loop expression")) - end - if !(ex.args[1] isa Expr && ex.args[1].head === :(=)) + captured = args[1:end-1] + ex = args[end] + if !(isa(ex, Expr) && ex.head === :for) + throw(ArgumentError("@parallel requires a `for` loop expression")) + end + if !(ex.args[1] isa Expr && ex.args[1].head === :(=)) throw(ArgumentError("nested outer loops are not currently supported by @parallel")) - end - iter = ex.args[1] - lidx = iter.args[1] # index - range = iter.args[2] - body = ex.args[2] - esc(quote - let range = $(range) - function bodyf(idx, iter, $(captured...)) - local $(lidx) = @inbounds iter[idx] - $(body) - nothing - end - lenr = length(range) - $pmap(bodyf, lenr, range, $(captured...)) - end - end) + end + iter = ex.args[1] + lidx = iter.args[1] # index + range = iter.args[2] + body = ex.args[2] + esc(quote + let range = $(range) + function bodyf(idx, iter, $(captured...)) + local $(lidx) = @inbounds iter[idx] + $(body) + nothing + end + lenr = length(range) + $pmap(bodyf, lenr, range, $(captured...)) + end + end) end diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 9e32023957..7a940259fa 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -1,10 +1,13 @@ function julia_activity_rule(f::LLVM.Function) + if startswith(LLVM.name(f)) == "japi3" + return + end mi, RT = 
enzyme_custom_extract_mi(f) - llRT, sret, returnRoots = get_return_info(RT) + llRT, sret, returnRoots = get_return_info(RT) retRemoved, parmsRemoved = removed_ret_parms(f) - + dl = string(LLVM.datalayout(LLVM.parent(f))) expectLen = (sret !== nothing) + (returnRoots !== nothing) @@ -12,11 +15,18 @@ function julia_activity_rule(f::LLVM.Function) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) continue end - expectLen+=1 + expectLen += 1 end expectLen -= length(parmsRemoved) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(f, i)), + ), + ) for i = 1:length(collect(parameters(f))) + ) if swiftself expectLen += 1 @@ -31,21 +41,28 @@ function julia_activity_rule(f::LLVM.Function) # TODO fix the attributor inlining such that this can assert always true if expectLen != length(parameters(f)) - msg = sprint() do io::IO - println(io, "Enzyme Internal Error (expectLen != length(parameters(f)))") - println(io, string(f)) - println(io, "expectLen=", string(expectLen)) - println(io, "swiftself=", string(swiftself)) - println(io, "sret=", string(sret)) - println(io, "returnRoots=", string(returnRoots)) - println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters)) - println(io, "retRemoved=", string(retRemoved)) - println(io, "parmsRemoved=", string(parmsRemoved)) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (expectLen != length(parameters(f)))") + println(io, string(f)) + println(io, "expectLen=", string(expectLen)) + println(io, "swiftself=", string(swiftself)) + println(io, "sret=", string(sret)) + println(io, "returnRoots=", string(returnRoots)) + println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters)) + println(io, "retRemoved=", string(retRemoved)) + 
println(io, "parmsRemoved=", string(parmsRemoved)) + end + throw(AssertionError(msg)) end - jlargs = classify_arguments(mi.specTypes, function_type(f), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) + jlargs = classify_arguments( + mi.specTypes, + function_type(f), + sret !== nothing, + returnRoots !== nothing, + swiftself, + parmsRemoved, + ) if !Enzyme.Compiler.no_type_setting(mi.specTypes; world)[1] for arg in jlargs @@ -54,12 +71,15 @@ function julia_activity_rule(f::LLVM.Function) end op_idx = arg.codegen.i - + typ, _ = enzyme_extract_parm_type(f, arg.codegen.i) @assert typ == arg.typ if guaranteed_const_nongen(arg.typ, world) - push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_inactive")) + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute("enzyme_inactive"), + ) end end end @@ -69,13 +89,19 @@ function julia_activity_rule(f::LLVM.Function) idx = 0 if !in(0, parmsRemoved) if guaranteed_const_nongen(RT, world) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_inactive"), + ) end - idx+=1 + idx += 1 end if returnRoots !== nothing if !in(idx, parmsRemoved) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_inactive"), + ) end end end diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 8e626d185f..1c1447dd65 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -1,16 +1,27 @@ -function array_inner(::Type{<:Array{T}}) where T +function array_inner(::Type{<:Array{T}}) where {T} return T end -function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, numArgs::Csize_t, Args::Ptr{LLVM.API.LLVMValueRef}, gutils::API.EnzymeGradientUtilsRef)::LLVM.API.LLVMValueRef +function array_shadow_handler( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + 
numArgs::Csize_t, + Args::Ptr{LLVM.API.LLVMValueRef}, + gutils::API.EnzymeGradientUtilsRef, +)::LLVM.API.LLVMValueRef inst = LLVM.Instruction(OrigCI) mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) ctx = LLVM.context(LLVM.Value(OrigCI)) gutils = GradientUtils(gutils) - legal, typ = abs_typeof(inst) + legal, typ, byref = abs_typeof(inst) if !legal - throw(AssertionError("Could not statically ahead-of-time determine allocation element type of "*string(inst))) + throw( + AssertionError( + "Could not statically ahead-of-time determine allocation element type of " * + string(inst), + ), + ) end typ = eltype(typ) @@ -25,7 +36,7 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV push!(valTys, API.VT_Primal) end - anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, #=lookup=#false) + anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, false) #=lookup=# prod = get_array_len(b, anti) @@ -33,11 +44,11 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV isunion = typ isa Union - LLT_ALIGN(x, sz) = (((x) + (sz)-1) & ~((sz)-1)) + LLT_ALIGN(x, sz) = (((x) + (sz) - 1) & ~((sz) - 1)) if !isunboxed elsz = sizeof(Ptr{Cvoid}) - al = elsz; + al = elsz else elsz = LLT_ALIGN(elsz, al) end @@ -63,7 +74,11 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV return ref end -function null_free_handler(B::LLVM.API.LLVMBuilderRef, ToFree::LLVM.API.LLVMValueRef, Fn::LLVM.API.LLVMValueRef)::LLVM.API.LLVMValueRef +function null_free_handler( + B::LLVM.API.LLVMBuilderRef, + ToFree::LLVM.API.LLVMValueRef, + Fn::LLVM.API.LLVMValueRef, +)::LLVM.API.LLVMValueRef return C_NULL end @@ -76,22 +91,78 @@ end @inline function register_alloc_rules() register_alloc_handler!( ("jl_alloc_array_1d", "ijl_alloc_array_1d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, 
API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) register_alloc_handler!( ("jl_alloc_array_2d", "ijl_alloc_array_2d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) register_alloc_handler!( ("jl_alloc_array_3d", "ijl_alloc_array_3d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) register_alloc_handler!( ("jl_new_array", "ijl_new_array"), - 
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) -end \ No newline at end of file +end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 08cd15facb..1985283da3 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1,5 +1,13 @@ -function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) +function enzyme_custom_setup_args( + B, + orig::LLVM.CallInst, + gutils::GradientUtils, + mi, + @nospecialize(RT), + reverse::Bool, + isKWCall::Bool, +) ops = collect(operands(orig)) called = ops[end] ops = ops[1:end-1] @@ -12,10 +20,10 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, actives = LLVM.Value[] - mixeds = Tuple{LLVM.Value, Type, LLVM.Value}[] + mixeds = Tuple{LLVM.Value,Type,LLVM.Value}[] uncacheable = get_uncacheable(gutils, orig) mode = get_mode(gutils) - + retRemoved, parmsRemoved = removed_ret_parms(orig) @assert length(parmsRemoved) == 0 @@ -25,8 +33,22 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, returnRoots = returnRoots !== nothing cv = LLVM.called_operand(orig) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(cv, i)))) for i in 1:length(collect(parameters(cv)))) - jlargs = classify_arguments(mi.specTypes, called_type(orig), sret, returnRoots, 
swiftself, parmsRemoved) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(cv, i)), + ), + ) for i = 1:length(collect(parameters(cv))) + ) + jlargs = classify_arguments( + mi.specTypes, + called_type(orig), + sret, + returnRoots, + swiftself, + parmsRemoved, + ) alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -49,23 +71,33 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, push!(overwritten, false) end if B !== nothing - if Core.Compiler.isconstType(arg.typ) && !Core.Compiler.isconstType(Const{arg.typ}) - llty = convert(LLVMType, Const{arg.typ}) - al0 = al = emit_allocobj!(B, Const{arg.typ}) - al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) - al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) - val = unsafe_to_llvm(B, arg.typ.parameters[1]) - store!(B, val, ptr) + if Core.Compiler.isconstType(arg.typ) && + !Core.Compiler.isconstType(Const{arg.typ}) + llty = convert(LLVMType, Const{arg.typ}) + al0 = al = emit_allocobj!(B, Const{arg.typ}) + al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) + + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) + val = unsafe_to_llvm(B, arg.typ.parameters[1]) + store!(B, val, ptr) - if any_jltypes(llty) - emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + if any_jltypes(llty) + emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + end + push!(args, al) + else + @assert isghostty(Const{arg.typ}) || + Core.Compiler.isconstType(Const{arg.typ}) end - push!(args, al) - else - @assert isghostty(Const{arg.typ}) || 
Core.Compiler.isconstType(Const{arg.typ}) - end end continue end @@ -82,7 +114,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, val = lookup_value(gutils, val, B) end - activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, #=isforeign=#false) + activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, false) #=isforeign=# if isKWCall && arg.arg_i == 2 Ty = arg.typ @@ -103,13 +135,21 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, if activep == API.DFT_CONSTANT Ty = Const{arg.typ} llty = convert(LLVMType, Ty) - arty = convert(LLVMType, arg.typ; allow_boxed=true) + arty = convert(LLVMType, arg.typ; allow_boxed = true) if B !== nothing al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) if value_type(val) != eltype(value_type(ptr)) val = load!(B, arty, val) end @@ -124,30 +164,45 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, push!(activity, Ty) - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world) == ActiveState) + elseif activep == API.DFT_OUT_DIFF || ( + mode != API.DEM_ForwardMode && + active_reg_inner(arg.typ, (), world) == ActiveState + ) Ty = Active{arg.typ} llty = convert(LLVMType, Ty) - arty = convert(LLVMType, arg.typ; allow_boxed=true) + arty = convert(LLVMType, arg.typ; allow_boxed = true) if B !== nothing al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, 
[LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) if value_type(val) != eltype(value_type(ptr)) if overwritten[end] emit_error( B, orig, - "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). " - * "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.", + "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). " * + "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.", ) end if arty == eltype(value_type(val)) val = load!(B, arty, val) else val = LLVM.UndefValue(arty) - emit_error(B, orig, "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + emit_error( + B, + orig, + "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))", + ) end end @@ -157,7 +212,11 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end else - emit_error(B, orig, "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + emit_error( + B, + orig, + "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))", + ) end push!(args, al) @@ -193,21 +252,21 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, if active_reg_inner(arg.typ, (), world) == MixedState # TODO batchmixedupnoneed shadowty = Base.RefValue{shadowty} - Ty = 
BatchMixedDuplicated{arg.typ, Int(width)} + Ty = BatchMixedDuplicated{arg.typ,Int(width)} mixed = true else if activep == API.DFT_DUP_ARG - Ty = BatchDuplicated{arg.typ, Int(width)} + Ty = BatchDuplicated{arg.typ,Int(width)} else @assert activep == API.DFT_DUP_NONEED - Ty = BatchDuplicatedNoNeed{arg.typ, Int(width)} + Ty = BatchDuplicatedNoNeed{arg.typ,Int(width)} end end end llty = convert(LLVMType, Ty) - arty = convert(LLVMType, arg.typ; allow_boxed=true) - iarty = convert(LLVMType, shadowty; allow_boxed=true) + arty = convert(LLVMType, arg.typ; allow_boxed = true) + iarty = convert(LLVMType, shadowty; allow_boxed = true) sarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, arty)) siarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, iarty)) if B !== nothing @@ -215,42 +274,63 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) needsload = false if value_type(val) != eltype(value_type(ptr)) val = load!(B, arty, val) if !mixed ptr_val = ival ival = UndefValue(siarty) - for idx in 1:width - ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + for idx = 1:width + ev = + (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) ld = load!(B, iarty, ev) - ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1) + ival = (width == 1) ? 
ld : insert_value!(B, ival, ld, idx - 1) end end needsload = true end store!(B, val, ptr) - iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) - + iptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 1), + ], + ) + if mixed RefTy = arg.typ if width != 1 - RefTy = NTuple{Int(width), RefTy} + RefTy = NTuple{Int(width),RefTy} end llrty = convert(LLVMType, RefTy) RefTy = Base.RefValue{RefTy} refal0 = refal = emit_allocobj!(B, RefTy) - refal = bitcast!(B, refal, LLVM.PointerType(llrty, addrspace(value_type(refal)))) + refal = bitcast!( + B, + refal, + LLVM.PointerType(llrty, addrspace(value_type(refal))), + ) @assert needsload ptr_val = ival ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llrty))) - for idx in 1:width - ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + for idx = 1:width + ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) ld = load!(B, llrty, ev) - ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1) + ival = (width == 1) ? ld : insert_value!(B, ival, ld, idx - 1) end store!(B, ival, refal) emit_writebarrier!(B, get_julia_inner_types(B, refal0, ival)) @@ -273,10 +353,16 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, return args, activity, (overwritten...,), actives, kwtup, mixeds end -function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B) +function enzyme_custom_setup_ret( + gutils::GradientUtils, + orig::LLVM.CallInst, + mi, + @nospecialize(RealRt), + B, +) width = get_width(gutils) mode = get_mode(gutils) - + world = enzyme_extract_world(LLVM.parent(LLVM.parent(orig))) needsShadowP = Ref{UInt8}(0) @@ -286,28 +372,41 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, # calls differential use analysis to determine needsprimal/shadow. 
However, since now this function # is used as part of differential use analysis, we need to avoid an ininite recursion. Thus use # the version without differential use if actual unreachable results are not available anyways. - uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) + uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) cmode = mode if cmode == API.DEM_ReverseModeGradient cmode = API.DEM_ReverseModePrimal end - activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 - API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, cmode) - else - actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) - if !isghostty(RealRt) - needsPrimalP[] = 1 - if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED - needsShadowP[] = 1 + activep = + if mode == API.DEM_ForwardMode || + API.EnzymeGradientUtilsGetUncacheableArgs( + gutils, + orig, + uncacheable, + length(uncacheable), + ) == 1 + API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + cmode, + ) + else + actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) + if !isghostty(RealRt) + needsPrimalP[] = 1 + if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED + needsShadowP[] = 1 + end end + actv end - actv - end needsPrimal = needsPrimalP[] != 0 origNeedsPrimal = needsPrimal _, sret, _ = get_return_info(RealRt) if sret !== nothing - activep = API.EnzymeGradientUtilsGetDiffeType(gutils, operands(orig)[1], #=isforeign=#false) + activep = API.EnzymeGradientUtilsGetDiffeType(gutils, operands(orig)[1], false) #=isforeign=# needsPrimal = activep == API.DFT_DUP_ARG || activep == API.DFT_CONSTANT needsShadowP[] = activep == API.DFT_DUP_ARG || activep == API.DFT_DUP_NONEED end @@ -315,13 +414,20 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, if !needsPrimal && activep == 
API.DFT_DUP_ARG activep = API.DFT_DUP_NONEED end - + if activep == API.DFT_CONSTANT RT = Const{RealRt} - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(RealRt, (), world, #=justActive=#Val(true)) == ActiveState) - if active_reg_inner(RealRt, (), world, #=justActive=#Val(false)) == MixedState && B !== nothing - emit_error(B, orig, "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information") + elseif activep == API.DFT_OUT_DIFF || ( + mode != API.DEM_ForwardMode && + active_reg_inner(RealRt, (), world, Val(true)) == ActiveState + ) #=justActive=# + if active_reg_inner(RealRt, (), world, Val(false)) == MixedState && B !== nothing #=justActive=# + emit_error( + B, + orig, + "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information", + ) end RT = Active{RealRt} @@ -329,20 +435,20 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, if width == 1 RT = Duplicated{RealRt} else - RT = BatchDuplicated{RealRt, Int(width)} + RT = BatchDuplicated{RealRt,Int(width)} end else @assert activep == API.DFT_DUP_NONEED if width == 1 RT = DuplicatedNoNeed{RealRt} else - RT = BatchDuplicatedNoNeed{RealRt, Int(width)} + RT = BatchDuplicatedNoNeed{RealRt,Int(width)} end end return RT, needsPrimal, needsShadowP[] != 0, origNeedsPrimal end -function custom_rule_method_error(world, fn, args...) +function custom_rule_method_error(world, fn, args...) 
throw(MethodError(fn, (args...,), world)) end @@ -354,7 +460,10 @@ end width = get_width(gutils) if shadowR != C_NULL - unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) + unsafe_store!( + shadowR, + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref, + ) end # TODO: don't inject the code multiple times for multiple calls @@ -370,10 +479,17 @@ end end # 2) Create activity, and annotate function spec - args, activity, overwritten, actives, kwtup, _ = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) - - C = EnzymeRules.FwdConfig{Bool(needsPrimal), Bool(needsShadow), Int(width), get_runtime_activity(gutils)} + args, activity, overwritten, actives, kwtup, _ = + enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, false, isKWCall) #=reverse=# + RT, needsPrimal, needsShadow, origNeedsPrimal = + enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) + + C = EnzymeRules.FwdConfig{ + Bool(needsPrimal), + Bool(needsShadow), + Int(width), + get_runtime_activity(gutils), + } alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -413,7 +529,7 @@ end llvmf = nested_codegen!(mode, mod, kwfunc, TT, world) fwd_RT = Core.Compiler.return_type(kwfunc, TT, world) else - TT = Tuple{typeof(world), typeof(kwfunc), TT.parameters...} + TT = Tuple{typeof(world),typeof(kwfunc),TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) pushfirst!(args, LLVM.ConstantInt(world)) fwd_RT = Union{} @@ -424,16 +540,23 @@ end llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world) fwd_RT = Core.Compiler.return_type(EnzymeRules.forward, TT, world) else - TT = Tuple{typeof(world), typeof(EnzymeRules.forward), TT.parameters...} + TT = 
Tuple{typeof(world),typeof(EnzymeRules.forward),TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) pushfirst!(args, LLVM.ConstantInt(world)) fwd_RT = Union{} end end - + push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(llvmf, i)), + ), + ) for i = 1:length(collect(parameters(llvmf))) + ) if swiftself pushfirst!(reinsert_gcmarker!(fn, B)) end @@ -452,7 +575,16 @@ end end if length(args) != length(parameters(llvmf)) - GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, string(value_type(llvmf)), orig, isKWCall, kwtup, TT, sret, returnRoots + GPUCompiler.@safe_error "Calling convention mismatch", + args, + llvmf, + string(value_type(llvmf)), + orig, + isKWCall, + kwtup, + TT, + sret, + returnRoots return false end @@ -471,7 +603,12 @@ end debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn")), collect(function_attributes(llvmf)))) + hasNoRet = any( + map( + k -> kind(k) == kind(EnumAttribute("noreturn")), + collect(function_attributes(llvmf)), + ), + ) if hasNoRet return false @@ -488,7 +625,11 @@ end end if swiftself attr = EnumAttribute("swiftself") - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+(sret !== nothing)), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + (sret !== nothing)), + attr, + ) end shadowV = C_NULL @@ -497,7 +638,18 @@ end if RT <: Const if needsPrimal if RealRt != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found 
"*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of const primal-only forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just return type " * + string(RealRt) * + " found " * + string(fwd_RT), + ) return false end if get_return_info(RealRt)[2] !== nothing @@ -508,7 +660,16 @@ end end else if Nothing != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of const no-primal forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just return type Nothing found " * + string(fwd_RT), + ) return false end end @@ -516,17 +677,28 @@ end if !needsPrimal ST = RealRt if width != 1 - ST = NTuple{Int(width), ST} + ST = NTuple{Int(width),ST} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of shadow-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of shadow-only forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just shadow type " * + string(ST) * + " found " * + string(fwd_RT), + ) return false end if get_return_info(RealRt)[2] !== nothing dval_ptr = invert_pointer(gutils, operands(orig)[1], B) - for idx in 1:width - ev = (width == 1) ? dval : extract_value!(B, dval, idx-1) - pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx-1) + for idx = 1:width + ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1) + pev = (width == 1) ? 
dval_ptr : extract_value!(B, dval_ptr, idx - 1) store!(B, res, pev) end else @@ -536,21 +708,32 @@ end ST = if width == 1 Duplicated{RealRt} else - BatchDuplicated{RealRt, Int(width)} + BatchDuplicated{RealRt,Int(width)} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of prima/shadow forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just shadow type " * + string(ST) * + " found " * + string(fwd_RT), + ) return false end if get_return_info(RealRt)[2] !== nothing val = new_from_original(gutils, operands(orig)[1]) store!(B, extract_value!(B, res, 0), val) - + dval_ptr = invert_pointer(gutils, operands(orig)[1], B) dval = extract_value!(B, res, 1) - for idx in 1:width - ev = (width == 1) ? dval : extract_value!(B, dval, idx-1) - pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx-1) + for idx = 1:width + ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1) + pev = (width == 1) ? 
dval_ptr : extract_value!(B, dval_ptr, idx - 1) store!(B, ev, pev) end else @@ -570,7 +753,11 @@ end else ni = new_from_original(gutils, orig) if value_type(ni) != LLVM.VoidType() - API.EnzymeGradientUtilsReplaceAWithB(gutils, ni, LLVM.UndefValue(value_type(ni))) + API.EnzymeGradientUtilsReplaceAWithB( + gutils, + ni, + LLVM.UndefValue(value_type(ni)), + ) end API.EnzymeGradientUtilsErase(gutils, ni) end @@ -578,7 +765,12 @@ end return false end -@inline function aug_fwd_mi(orig::LLVM.CallInst, gutils::GradientUtils, forward=false, B=nothing) +@inline function aug_fwd_mi( + orig::LLVM.CallInst, + gutils::GradientUtils, + forward = false, + B = nothing, +) width = get_width(gutils) # 1) extract out the MI from attributes @@ -586,8 +778,10 @@ end isKWCall = isKWCallSignature(mi.specTypes) # 2) Create activity, and annotate function spec - args, activity, overwritten, actives, kwtup, mixeds = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) + args, activity, overwritten, actives, kwtup, mixeds = + enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, !forward, isKWCall) #=reverse=# + RT, needsPrimal, needsShadow, origNeedsPrimal = + enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) needsShadowJL = if RT <: Active false @@ -598,8 +792,14 @@ end fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) - C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} - + C = EnzymeRules.RevConfig{ + Bool(needsPrimal), + Bool(needsShadowJL), + Int(width), + overwritten, + get_runtime_activity(gutils), + } + mode = get_mode(gutils) ami = nothing @@ -617,10 +817,14 @@ end kwfunc = Core.kwfunc(EnzymeRules.augmented_primal) try ami = GPUCompiler.methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) - @safe_debug "Applying custom augmented_primal rule (kwcall)" 
TT=augprimal_TT + @safe_debug "Applying custom augmented_primal rule (kwcall)" TT = augprimal_TT catch e - augprimal_TT = Tuple{typeof(world), typeof(kwfunc), augprimal_TT.parameters...} - ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), augprimal_TT, world) + augprimal_TT = Tuple{typeof(world),typeof(kwfunc),augprimal_TT.parameters...} + ami = GPUCompiler.methodinstance( + typeof(custom_rule_method_error), + augprimal_TT, + world, + ) if forward pushfirst!(args, LLVM.ConstantInt(world)) end @@ -632,24 +836,57 @@ end augprimal_TT = Tuple{augprimal_tt...} try - ami = GPUCompiler.methodinstance(Core.Typeof(EnzymeRules.augmented_primal), augprimal_TT, world) - @safe_debug "Applying custom augmented_primal rule" TT=augprimal_TT + ami = GPUCompiler.methodinstance( + Core.Typeof(EnzymeRules.augmented_primal), + augprimal_TT, + world, + ) + @safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT catch e - augprimal_TT = Tuple{typeof(world), typeof(EnzymeRules.augmented_primal), augprimal_TT.parameters...} - ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), augprimal_TT, world) + augprimal_TT = Tuple{ + typeof(world), + typeof(EnzymeRules.augmented_primal), + augprimal_TT.parameters..., + } + ami = GPUCompiler.methodinstance( + typeof(custom_rule_method_error), + augprimal_TT, + world, + ) if forward pushfirst!(args, LLVM.ConstantInt(world)) end end end - return ami, augprimal_TT, (args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds) + return ami, + augprimal_TT, + ( + args, + activity, + overwritten, + actives, + kwtup, + RT, + needsPrimal, + needsShadow, + origNeedsPrimal, + mixeds, + ) end @inline function has_aug_fwd_rule(orig, gutils) return aug_fwd_mi(orig, gutils)[1] !== nothing end -@register_rev function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef +@register_rev function enzyme_custom_common_rev( 
+ forward::Bool, + B, + orig::LLVM.CallInst, + gutils, + normalR, + shadowR, + tape, +)::LLVM.API.LLVMValueRef ctx = LLVM.context(orig) @@ -657,7 +894,7 @@ end shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) if shadowR != C_NULL - unsafe_store!(shadowR,UndefValue(shadowType).ref) + unsafe_store!(shadowR, UndefValue(shadowType).ref) end # TODO: don't inject the code multiple times for multiple calls @@ -668,7 +905,16 @@ end # 2) Create activity, and annotate function spec ami, augprimal_TT, setup = aug_fwd_mi(orig, gutils, forward, B) - args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds = setup + args, + activity, + overwritten, + actives, + kwtup, + RT, + needsPrimal, + needsShadow, + origNeedsPrimal, + mixeds = setup needsShadowJL = if RT <: Active false @@ -676,7 +922,13 @@ end needsShadow end - C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} + C = EnzymeRules.RevConfig{ + Bool(needsPrimal), + Bool(needsShadowJL), + Int(width), + overwritten, + get_runtime_activity(gutils), + } alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -690,10 +942,24 @@ end @assert ami !== nothing target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) - aug_RT = something(Core.Compiler.typeinf_type(GPUCompiler.get_interpreter(CompilerJob(ami, CompilerConfig(target, params; kernel=false), world)), ami.def, ami.specTypes, ami.sparam_vals), Any) + aug_RT = something( + Core.Compiler.typeinf_type( + GPUCompiler.get_interpreter( + CompilerJob(ami, CompilerConfig(target, params; kernel = false), world), + ), + ami.def, + ami.specTypes, + ami.sparam_vals, + ), + Any, + ) if kwtup !== nothing && kwtup <: Duplicated @safe_debug "Non-constant keyword argument found for " augprimal_TT - emit_error(B, orig, "Enzyme: Non-constant keyword argument found for " * 
string(augprimal_TT)) + emit_error( + B, + orig, + "Enzyme: Non-constant keyword argument found for " * string(augprimal_TT), + ) return C_NULL end @@ -702,15 +968,26 @@ end TapeT = Nothing - if (aug_RT <: EnzymeRules.AugmentedReturn || aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT isa UnionAll) && !(aug_RT isa Union) && !(aug_RT === Union{}) + if ( + aug_RT <: EnzymeRules.AugmentedReturn || + aug_RT <: EnzymeRules.AugmentedReturnFlexShadow + ) && + !(aug_RT isa UnionAll) && + !(aug_RT isa Union) && + !(aug_RT === Union{}) TapeT = EnzymeRules.tape_type(aug_RT) - elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturn) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name + elseif (aug_RT isa UnionAll) && + (aug_RT <: EnzymeRules.AugmentedReturn) && + aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar TapeT = aug_RT.body.parameters[3].ub else TapeT = Any end - elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name + elseif (aug_RT isa UnionAll) && + (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && + aug_RT.body.name == + EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar TapeT = aug_RT.body.parameters[3].ub else @@ -749,11 +1026,11 @@ end if isKWCall rkwfunc = Core.kwfunc(EnzymeRules.reverse) if EnzymeRules.isapplicable(rkwfunc, rev_TT; world) - @safe_debug "Applying custom reverse rule (kwcall)" TT=rev_TT + @safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) else - rev_TT = Tuple{typeof(world), typeof(rkwfunc), rev_TT.parameters...} + rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} llvmf = nested_codegen!(mode, mod, 
custom_rule_method_error, rev_TT, world) pushfirst!(args, LLVM.ConstantInt(world)) rev_RT = Union{} @@ -761,11 +1038,12 @@ end end else if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) - @safe_debug "Applying custom reverse rule" TT=rev_TT + @safe_debug "Applying custom reverse rule" TT = rev_TT llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) else - rev_TT = Tuple{typeof(world), typeof(EnzymeRules.reverse), rev_TT.parameters...} + rev_TT = + Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) pushfirst!(args, LLVM.ConstantInt(world)) rev_RT = Union{} @@ -780,7 +1058,7 @@ end tapeV = C_NULL if forward && needsTape - tapeV = LLVM.UndefValue(convert(LLVMType, TapeT; allow_boxed=true)).ref + tapeV = LLVM.UndefValue(convert(LLVMType, TapeT; allow_boxed = true)).ref end # if !forward @@ -796,28 +1074,59 @@ end # llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}, world) # end - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(llvmf, i)), + ), + ) for i = 1:length(collect(parameters(llvmf))) + ) miRT = enzyme_custom_extract_mi(llvmf)[2] _, sret, returnRoots = get_return_info(miRT) sret_union = is_sret_union(miRT) - if sret_union - emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " had a union sret of type "*string(miRT)*" which is not currently supported") + if sret_union + emit_error( + B, + orig, + "Enzyme: Augmented forward pass custom rule " * + string(augprimal_TT) * + " had a union sret of type " * + string(miRT) * + " which is not currently supported", + ) return tapeV end if !forward - 
funcTy = rev_TT.parameters[isKWCall ? 4 : 2] + funcTy = rev_TT.parameters[isKWCall ? 4 : 2] if needsTape @assert tape != C_NULL - tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy) + (!applicablefn) - trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself + (RT <: Active) + tape_idx = + 1 + + (kwtup !== nothing && !isghostty(kwtup)) + + !isghostty(funcTy) + + (!applicablefn) + trueidx = + tape_idx + + (sret !== nothing) + + (returnRoots !== nothing) + + swiftself + + (RT <: Active) innerTy = value_type(parameters(llvmf)[trueidx]) if innerTy != value_type(tape) - if isabstracttype(TapeT) || TapeT isa UnionAll || TapeT == Tuple || TapeT.layout == C_NULL || TapeT == Array + if isabstracttype(TapeT) || + TapeT isa UnionAll || + TapeT == Tuple || + TapeT.layout == C_NULL || + TapeT == Array msg = sprint() do io - println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") + println( + io, + "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))", + ) println(io, "tape_idx=", tape_idx) println(io, "true_idx=", trueidx) println(io, "isKWCall=", isKWCall) @@ -840,7 +1149,7 @@ end end throw(AssertionError(msg)) end - llty = convert(LLVMType, TapeT; allow_boxed=true) + llty = convert(LLVMType, TapeT; allow_boxed = true) al0 = al = emit_allocobj!(B, TapeT) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(B, tape, al) @@ -855,18 +1164,18 @@ end llty = convert(LLVMType, RT) - if API.EnzymeGradientUtilsGetDiffeType(gutils, orig, #=isforeign=#false) == API.DFT_OUT_DIFF + if API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) == API.DFT_OUT_DIFF #=isforeign=# val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B)) API.EnzymeGradientUtilsSetDiffe(gutils, orig, LLVM.null(value_type(val)), B) else - llety = convert(LLVMType, eltype(RT); allow_boxed=true) - ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B) + llety = 
convert(LLVMType, eltype(RT); allow_boxed = true) + ptr_val = invert_pointer(gutils, operands(orig)[1+!isghostty(funcTy)], B) val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety))) - for idx in 1:width - ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + for idx = 1:width + ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) ld = load!(B, llety, ev) store!(B, LLVM.null(llety), ev) - val = (width == 1 ) ? ld : insert_value!(B, val, ld, idx-1) + val = (width == 1) ? ld : insert_value!(B, val, ld, idx - 1) end end @@ -874,13 +1183,28 @@ end al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) store!(B, val, ptr) if any_jltypes(llty) emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)) + (!applicablefn), al) + insert!( + args, + 1 + + (!isghostty(funcTy)) + + (kwtup !== nothing && !isghostty(kwtup)) + + (!applicablefn), + al, + ) end end @@ -902,16 +1226,26 @@ end end if length(args) != length(parameters(llvmf)) - GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, orig, isKWCall, kwtup, augprimal_TT, rev_TT, fn, sret, returnRoots + GPUCompiler.@safe_error "Calling convention mismatch", + args, + llvmf, + orig, + isKWCall, + kwtup, + augprimal_TT, + rev_TT, + fn, + sret, + returnRoots return tapeV end - + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - for i in 1:length(args) - party = value_type(parameters(llvmf)[i]) + for i = 1:length(args) + party = value_type(parameters(llvmf)[i]) if value_type(args[i]) != party if party == T_prjlvalue while true @@ 
-939,7 +1273,15 @@ end println(io, "args[i] = ", args[i]) println(io, "party = ", party) end - args[i] = calling_conv_fixup(B, args[i], party, LLVM.UndefValue(party), Cuint[], Cuint[], msg) + args[i] = calling_conv_fixup( + B, + args[i], + party, + LLVM.UndefValue(party), + Cuint[], + Cuint[], + msg, + ) end res = LLVM.call!(B, LLVM.function_type(llvmf), llvmf, args) @@ -947,7 +1289,12 @@ end debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn")), collect(function_attributes(llvmf)))) + hasNoRet = any( + map( + k -> kind(k) == kind(EnumAttribute("noreturn")), + collect(function_attributes(llvmf)), + ), + ) if hasNoRet return tapeV @@ -959,13 +1306,21 @@ end else attr = EnumAttribute("sret") end - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+swiftself), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + swiftself), + attr, + ) res = load!(B, eltype(value_type(parameters(llvmf)[1+swiftself])), sret) API.SetMustCache!(res) end if swiftself attr = EnumAttribute("swiftself") - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+(sret !== nothing)+(returnRoots !== nothing)), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + (sret !== nothing) + (returnRoots !== nothing)), + attr, + ) end shadowV = C_NULL @@ -975,35 +1330,86 @@ end if forward ShadT = RealRt if width != 1 - ShadT = NTuple{Int(width), RealRt} + ShadT = NTuple{Int(width),RealRt} end - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? 
ShadT : Nothing, + TapeT, + } if aug_RT != ST if aug_RT <: EnzymeRules.AugmentedReturnFlexShadow - if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed=true) != - convert(LLVMType, EnzymeRules.shadow_type(ST) ; allow_boxed=true) - emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " flex shadow ABI return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) + if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed = true) != + convert(LLVMType, EnzymeRules.shadow_type(ST); allow_boxed = true) + emit_error( + B, + orig, + "Enzyme: Augmented forward pass custom rule " * + string(augprimal_TT) * + " flex shadow ABI return type mismatch, expected " * + string(ST) * + " found " * + string(aug_RT), + ) return tapeV end - ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturnFlexShadow{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, + TapeT, + } end end abstract = false if aug_RT != ST - abs = (EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, T} where T) + abs = ( + EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? ShadT : Nothing, + T, + } where {T} + ) if aug_RT <: abs abstract = true else - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any} - emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) + ST = EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? 
ShadT : Nothing, + Any, + } + emit_error( + B, + orig, + "Enzyme: Augmented forward pass custom rule " * + string(augprimal_TT) * + " return type mismatch, expected " * + string(ST) * + " found " * + string(aug_RT), + ) return tapeV end end resV = if abstract - StructTy = convert(LLVMType, EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Nothing}) + StructTy = convert( + LLVMType, + EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? ShadT : Nothing, + Nothing, + }, + ) if StructTy != LLVM.VoidType() - load!(B, StructTy, bitcast!(B, res, LLVM.PointerType(StructTy, addrspace(value_type(res))))) + load!( + B, + StructTy, + bitcast!( + B, + res, + LLVM.PointerType(StructTy, addrspace(value_type(res))), + ), + ) else res end @@ -1022,7 +1428,7 @@ end @assert value_type(normalV) == value_type(orig) normalV = normalV.ref end - idx+=1 + idx += 1 end if needsShadow if needsShadowJL @@ -1031,10 +1437,11 @@ end if get_return_info(RealRt)[2] !== nothing dval = invert_pointer(gutils, operands(orig)[1], B) - for idx in 1:width - to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1) + for idx = 1:width + to_store = + (width == 1) ? shadowV : extract_value!(B, shadowV, idx - 1) - store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1) + store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx - 1) store!(B, to_store, store_ptr) end @@ -1043,7 +1450,7 @@ end @assert value_type(shadowV) == shadowType shadowV = shadowV.ref end - idx+=1 + idx += 1 end end if needsTape @@ -1052,23 +1459,37 @@ end else extract_value!(B, res, idx).ref end - idx+=1 + idx += 1 end else - Tys = (A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width), eltype(A)}) : Nothing for A in activity[2+isKWCall:end]) + Tys = ( + A <: Active ? (width == 1 ? 
eltype(A) : NTuple{Int(width),eltype(A)}) : Nothing for A in activity[2+isKWCall:end] + ) ST = Tuple{Tys...} if rev_RT != ST - emit_error(B, orig, "Enzyme: Reverse pass custom rule " * string(rev_TT) * " return type mismatch, expected "*string(ST)*" found "* string(rev_RT)) + emit_error( + B, + orig, + "Enzyme: Reverse pass custom rule " * + string(rev_TT) * + " return type mismatch, expected " * + string(ST) * + " found " * + string(rev_RT), + ) return tapeV end - if length(actives) >= 1 && !isa(value_type(res), LLVM.StructType) && !isa(value_type(res), LLVM.ArrayType) - GPUCompiler.@safe_error "Shadow arg calling convention mismatch found return ", res + if length(actives) >= 1 && + !isa(value_type(res), LLVM.StructType) && + !isa(value_type(res), LLVM.ArrayType) + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found return ", + res return tapeV end idx = 0 dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))) - Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active) + Tys2 = (eltype(A) for A in activity[(2+isKWCall):end] if A <: Active) seen = TypeTreeTable() for (v, Ty) in zip(actives, Tys2) TT = typetree(Ty, ctx, dl, seen) @@ -1079,24 +1500,35 @@ end size = sizeof(Ty) align = 0 premask = C_NULL - API.EnzymeGradientUtilsAddToInvertedPointerDiffeTT(gutils, orig, C_NULL, TT, size, v, ext, B, align, premask) + API.EnzymeGradientUtilsAddToInvertedPointerDiffeTT( + gutils, + orig, + C_NULL, + TT, + size, + v, + ext, + B, + align, + premask, + ) else @assert value_type(ext) == shadowVType API.EnzymeGradientUtilsAddToDiffe(gutils, v, ext, B, Typ) end - idx+=1 + idx += 1 end for (ptr_val, argTyp, refal) in mixeds RefTy = argTyp if width != 1 - RefTy = NTuple{Int(width), RefTy} + RefTy = NTuple{Int(width),RefTy} end curs = load!(B, convert(LLVMType, RefTy), refal) - for idx in 1:width - evp = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) - evcur = (width == 1) ? 
curs : extract_value!(B, curs, idx-1) + for idx = 1:width + evp = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) + evcur = (width == 1) ? curs : extract_value!(B, curs, idx - 1) store_nonjl_types!(B, evcur, evp) end end @@ -1121,10 +1553,12 @@ end @register_aug function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) + if is_constant_value(gutils, orig) && + is_constant_inst(gutils, orig) && + !has_aug_fwd_rule(orig, gutils) return true end - tape = enzyme_custom_common_rev(#=forward=#true, B, orig, gutils, normalR, shadowR, #=tape=#nothing) + tape = enzyme_custom_common_rev(true, B, orig, gutils, normalR, shadowR, nothing) #=tape=# if tape != C_NULL unsafe_store!(tapeR, tape) end @@ -1132,34 +1566,41 @@ end end @register_rev function enzyme_custom_rev(B, orig, gutils, tape) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) + if is_constant_value(gutils, orig) && + is_constant_inst(gutils, orig) && + !has_aug_fwd_rule(orig, gutils) return end - enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) + enzyme_custom_common_rev(false, B, orig, gutils, C_NULL, C_NULL, tape) #=tape=# return nothing end @register_diffuse function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) # use default - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) + if is_constant_value(gutils, orig) && + is_constant_inst(gutils, orig) && + !has_aug_fwd_rule(orig, gutils) return (false, true) end non_rooting_use = false fop = called_operand(orig)::LLVM.Function for (i, v) in enumerate(operands(orig)[1:end-1]) - if v == val - if !any(a->kind(a) == kind(StringAttribute("enzymejl_returnRoots")), collect(parameter_attributes(fop, i))) - non_rooting_use = true - break - end - end + if v == val 
+ if !any( + a -> kind(a) == kind(StringAttribute("enzymejl_returnRoots")), + collect(parameter_attributes(fop, i)), + ) + non_rooting_use = true + break + end + end end - + # If the operand is just rooting, we don't need it and should override defaults if !non_rooting_use - return (false, false) + return (false, false) end - + # don't use default and always require the arg return (true, false) end diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 01edec7118..75bc415654 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,4 +1,13 @@ -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false, reverse=false) +function setup_macro_wraps( + forwardMode::Bool, + N::Int, + Width::Int, + base = nothing, + iterate = false; + func = true, + mixed_or_active = false, + reverse = false, +) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -8,7 +17,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, dfns = Union{Symbol,Expr}[:df] base_idx = 1 if func - for w in 2:Width + for w = 2:Width if base === nothing shad = Symbol("df_$w") t = Symbol("DF__$w*") @@ -22,7 +31,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(dfns, shad) end end - for i in 1:N + for i = 1:N if base === nothing prim = Symbol("primal_$i") t = Symbol("PT_$i") @@ -37,7 +46,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(primargs, prim) push!(primtypes, t) shadows = Union{Symbol,Expr}[] - for w in 1:Width + for w = 1:Width if base === nothing shad = Symbol("shadow_$(i)_$w") t = Symbol("ST_$(i)_$w") @@ -62,7 +71,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, wrapped = Expr[] modbetween = Expr[:(MB[1])] active_refs = Expr[] - for i in 1:N + for i = 1:N if iterate push!(modbetween, quote 
ntuple(Val(length($(primargs[i])))) do _ @@ -73,7 +82,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end aref = Symbol("active_ref_$i") push!(active_refs, quote - $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)); + $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)) end) expr = if iterate if forwardMode @@ -83,34 +92,60 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end else quote - iterate_unwrap_fwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_fwd_batchdup( + Val($Width), + $(primargs[i]), + $(shadowargs[i]), + ) end end :( - if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - @assert $(primtypes[i]) !== DataType - $dupexpr - else - map(Const, $(primargs[i])) - end + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + $dupexpr + else + map(Const, $(primargs[i])) + end ) else mixexpr = if Width == 1 quote - iterate_unwrap_augfwd_mix(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_mix( + Val($reverse), + refs, + $(primargs[i]), + $(shadowargs[i]), + ) end else quote - iterate_unwrap_augfwd_batchmix(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_batchmix( + Val($reverse), + refs, + Val($Width), + $(primargs[i]), + $(shadowargs[i]), + ) end end dupexpr = if Width == 1 quote - iterate_unwrap_augfwd_dup(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_dup( + Val($reverse), + refs, + $(primargs[i]), + $(shadowargs[i]), + ) end else quote - iterate_unwrap_augfwd_batchdup(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_batchdup( + Val($reverse), + refs, + Val($Width), + $(primargs[i]), + $(shadowargs[i]), + ) end end :( @@ -132,7 +167,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, if forwardMode quote if 
ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + $((Width == 1) ? :Duplicated : :BatchDuplicated)( + $(primargs[i]), + $(shadowargs[i]), + ) else Const($(primargs[i])) end @@ -144,9 +182,15 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, if $aref == ActiveState Active($(primargs[i])) elseif $aref == MixedState - $((Width == 1) ? :MixedDuplicated : :BatchMixedDuplicated)($(primargs[i]), $(shadowargs[i])) + $((Width == 1) ? :MixedDuplicated : :BatchMixedDuplicated)( + $(primargs[i]), + $(shadowargs[i]), + ) else - $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + $((Width == 1) ? :Duplicated : :BatchDuplicated)( + $(primargs[i]), + $(shadowargs[i]), + ) end else Const($(primargs[i])) @@ -157,8 +201,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(wrapped, expr) end - any_mixed = quote false end - for i in 1:N + any_mixed = quote + false + end + for i = 1:N aref = Symbol("active_ref_$i") if mixed_or_active any_mixed = :($any_mixed || $aref == MixedState || $aref == ActiveState) @@ -169,19 +215,27 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(active_refs, quote any_mixed = $any_mixed end) - return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs + return primargs, + shadowargs, + primtypes, + allargs, + typeargs, + wrapped, + batchshadowargs, + modbetween, + active_refs end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) - nnothing = Vector{Nothing}(undef, Width+1) - nres = Vector{Expr}(undef, Width+1) + nnothing = Vector{Nothing}(undef, Width + 1) + nres = Vector{Expr}(undef, Width + 1) fill!(nnothing, nothing) fill!(nres, :(res[1])) - ModifiedBetween = Vector{Bool}(undef, N+1) + ModifiedBetween = Vector{Bool}(undef, N + 1) fill!(ModifiedBetween, false) 
ElTypes = Vector{Expr}(undef, N) Types = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds ElTypes[i] = :(eltype(Core.Typeof(args[$i]))) @inbounds Types[i] = :(Core.Typeof(args[$i])) end @@ -195,7 +249,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) :(Duplicated(f, df)) else fargs = [:df] - for i in 2:Width + for i = 2:Width push!(fargs, Symbol("df_$i")) end :(BatchDuplicated(f, ($(fargs...),))) @@ -203,7 +257,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) dupty = if Width == 1 :(Duplicated{FT}) else - :(BatchDuplicated{FT, $Width}) + :(BatchDuplicated{FT,$Width}) end return quote @@ -221,7 +275,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end if $Width != 1 if annotation <: Duplicated - annotation = BatchDuplicated{rt, $Width} + annotation = BatchDuplicated{rt,$Width} end end @@ -233,7 +287,20 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = thunk(opt_mi, dupClosure ? $dupty : Const{FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward = thunk( + opt_mi, + dupClosure ? $dupty : Const{FT}, + annotation, + tt′, + Val(API.DEM_ForwardMode), + width, + Val(($(ModifiedBetween...),)), + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# res = forward(dupClosure ? $dup : Const(f), args...) 
@@ -253,44 +320,69 @@ function func_runtime_generic_fwd(N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote - function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,ReturnType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 +@generated function runtime_generic_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 _, _, primtypes, _, _, wrapped, _, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) - nres = Vector{Symbol}(undef, Width+1) + nres = Vector{Symbol}(undef, Width + 1) fill!(nres, :origRet) nzeros = Vector{Expr}(undef, Width) fill!(nzeros, :(Ref(make_zero(origRet)))) - + ElTypes = Vector{Expr}(undef, N) MakeTypes = Vector{Expr}(undef, N) Types = Vector{Symbol}(undef, N) MixedTypes = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds ElTypes[i] = :(eltype($(Symbol("type_$i")))) @inbounds MakeTypes[i] = :($(Symbol("type_$i")) = Core.Typeof(args[$i])) @inbounds Types[i] = Symbol("type_$i") - 
@inbounds MixedTypes[i] = :($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))) + @inbounds MixedTypes[i] = :( + $(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : + $(Symbol("type_$i")) + ) end ending = if Width == 1 quote if annotation <: MixedDuplicated shadow_return = initShadow - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, shadow_return, tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow, tape)) end end @@ -298,33 +390,39 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) quote if annotation <: BatchMixedDuplicated shadow_return = (initShadow...,) - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) end end end - + shadowretinit = if Width == 1 :(Ref(make_zero(origRet))) else :(($(nzeros...),)) end - + shadowretret = if Width == 1 :(return ReturnType((origRet, shadow_return, tape))) else :(return ReturnType((origRet, shadow_return..., tape))) end - + dup = if Width == 1 :(Duplicated(f, df)) else fargs = [:df] - for i in 2:Width + for i = 2:Width push!(fargs, Symbol("df_$i")) end :(BatchDuplicated(f, 
($(fargs...),))) @@ -332,14 +430,14 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) dupty = if Width == 1 :(Duplicated{FT}) else - :(BatchDuplicated{FT, $Width}) + :(BatchDuplicated{FT,$Width}) end return quote $(active_refs...) args = ($(wrapped...),) $(MakeTypes...) - + FT = Core.Typeof(f) dupClosure0 = if ActivityTup[1] !guaranteed_const(FT) @@ -352,18 +450,29 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotationA = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + BatchDuplicated{rt,$Width} elseif $Width != 1 && annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt, $Width} + BatchMixedDuplicated{rt,$Width} else annotation0 end world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, - annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotationA, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# internal_tape, origRet, initShadow = forward(dupClosure0 ? $dup : Const(f), args...) 
annotation = annotationA @@ -371,11 +480,17 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) resT = typeof(origRet) if annotation <: Const shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType(($(nres...), tape)) elseif annotation <: Active shadow_return = $shadowretinit - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) $shadowretret end @@ -384,31 +499,51 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end function func_runtime_generic_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = + setup_macro_wraps(false, N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, RuntimeActivity, F, DF, $(typeargs...)} + function runtime_generic_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + )::ReturnType where {ActivityTup,MB,ReturnType,RuntimeActivity,F,DF,$(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where 
{ActivityTup, MB, RuntimeActivity, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) +@generated function runtime_generic_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +)::ReturnType where {ActivityTup,MB,RuntimeActivity,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, _, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) end -function nonzero_active_data(x::T) where T<: AbstractFloat +function nonzero_active_data(x::T) where {T<:AbstractFloat} return x != zero(T) end -nonzero_active_data(::T) where T<: Base.RefValue = false -nonzero_active_data(::T) where T<: Array = false -nonzero_active_data(::T) where T<: Ptr = false +nonzero_active_data(::T) where {T<:Base.RefValue} = false +nonzero_active_data(::T) where {T<:Array} = false +nonzero_active_data(::T) where {T<:Ptr} = false -function nonzero_active_data(x::T) where T +function nonzero_active_data(x::T) where {T} if guaranteed_const(T) return false end @@ -427,21 +562,33 @@ end function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs) outs = [] - for i in 1:N - for w in 1:Width + for i = 1:N + for w = 1:Width expr = if Width == 1 :(tup[$i]) else :(tup[$i][$w]) end shad = shadowargs[i][w] - out = :(if tup[$i] === nothing - elseif $shad isa Base.RefValue - $shad[] = recursive_add($shad[], $expr) + out = quote + if tup[$i] === nothing + elseif $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)*" tup[i]="*string(tup[$i])*" i="*string($i)*" w="*string($w)*" tup="*string(tup)) + 
error( + "Enzyme Mutability Error: Cannot add one in place to immutable value " * + string($shad) * + " tup[i]=" * + string(tup[$i]) * + " i=" * + string($i) * + " w=" * + string($w) * + " tup=" * + string(tup), + ) end - ) + end push!(outs, out) end end @@ -450,7 +597,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act shadowret = :(tape.shadow_return[]) else shadowret = [] - for w in 1:Width + for w = 1:Width push!(shadowret, :(tape.shadow_return[$w][])) end shadowret = :(($(shadowret...),)) @@ -459,7 +606,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act ElTypes = Vector{Expr}(undef, N) MakeTypes = Vector{Expr}(undef, N) Types = Vector{Symbol}(undef, N) - for i in 1:N + for i = 1:N @inbounds ElTypes[i] = :(eltype($(Symbol("type_$i")))) @inbounds MakeTypes[i] = :($(Symbol("type_$i")) = Core.Typeof(args[$i])) @inbounds Types[i] = Symbol("type_$i") @@ -469,7 +616,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act :(Duplicated(f, df)) else fargs = [:df] - for i in 2:Width + for i = 2:Width push!(fargs, Symbol("df_$i")) end :(BatchDuplicated(f, ($(fargs...),))) @@ -477,14 +624,14 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act dupty = if Width == 1 :(Duplicated{FT}) else - :(BatchDuplicated{FT, $Width}) + :(BatchDuplicated{FT,$Width}) end quote $(active_refs...) args = ($(wrapped...),) $(MakeTypes...) 
- + FT = Core.Typeof(f) dupClosure0 = if ActivityTup[1] !guaranteed_const(FT) @@ -497,7 +644,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + BatchDuplicated{rt,$Width} else annotation0 end @@ -505,36 +652,84 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act world = codegen_world_age(FT, tt) opt_mi = Val(world) - _, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, - annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - - tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated - adjoint(dupClosure0 ? $dup : Const(f), args..., $shadowret, tape.internal_tape)[1] - else - adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] - end + _, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotation, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# + + tup = + if annotation0 <: Active || + annotation0 <: MixedDuplicated || + annotation0 <: BatchMixedDuplicated + adjoint( + dupClosure0 ? $dup : Const(f), + args..., + $shadowret, + tape.internal_tape, + )[1] + else + adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] + end $(outs...) 
+ return nothing end end function func_runtime_generic_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) - body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width) + body = + body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) quote - function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} + function runtime_generic_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,MB,TapeType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, MB, RuntimeActivity, Width, TapeType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) +@generated function runtime_generic_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + allargs..., +) where {ActivityTup,MB,RuntimeActivity,Width,TapeType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_generic_rev( + N, + Width, + wrapped, + primtypes, + batchshadowargs, + active_refs, + ) end @inline concat() = () @@ -545,7 +740,8 @@ end @inline iterate_unwrap_inner_fwd(x::Const) = (map(Const, x.val)...,) @inline iterate_unwrap_inner_fwd(x::Duplicated) = (map(Duplicated, x.val, x.dval)...,) @inline batch_dup_tuple(x, vals...) = BatchDuplicated(x, (vals...,)) -@inline iterate_unwrap_inner_fwd(x::BatchDuplicated) = (map(batch_dup_tuple, x.val, x.dval...)...,) +@inline iterate_unwrap_inner_fwd(x::BatchDuplicated) = + (map(batch_dup_tuple, x.val, x.dval...)...,) @inline function iterate_unwrap_fwd(args...) 
ntuple(Val(length(args))) do i @@ -596,7 +792,7 @@ end end end -function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse, T2} +function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse,T2} if reverse return popfirst!(vals) else @@ -606,11 +802,21 @@ function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse, end end -function push_if_not_ref(::Val{reverse}, vals, darg::Base.RefValue{T2}, ::Type{T2}) where {reverse, T2} +function push_if_not_ref( + ::Val{reverse}, + vals, + darg::Base.RefValue{T2}, + ::Type{T2}, +) where {reverse,T2} return darg end -@inline function iterate_unwrap_augfwd_dup(::Val{reverse}, vals, args, dargs) where reverse +@inline function iterate_unwrap_augfwd_dup( + ::Val{reverse}, + vals, + args, + dargs, +) where {reverse} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -622,14 +828,23 @@ end Active(arg) elseif actreg == MixedState darg = Base.inferencebarrier(dargs[i]) - MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) + MixedDuplicated( + arg, + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}, + ) else Duplicated(arg, dargs[i]) end end end -@inline function iterate_unwrap_augfwd_batchdup(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width} +@inline function iterate_unwrap_augfwd_batchdup( + ::Val{reverse}, + vals, + ::Val{Width}, + args, + dargs, +) where {reverse,Width} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -640,11 +855,14 @@ end elseif actreg == ActiveState Active(arg) elseif actreg == MixedState - BatchMixedDuplicated(arg, ntuple(Val(Width)) do j - Base.@_inline_meta - darg = Base.inferencebarrier(dargs[j][i]) - push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} - end) + BatchMixedDuplicated( + arg, + ntuple(Val(Width)) do j + Base.@_inline_meta + darg = Base.inferencebarrier(dargs[j][i]) + push_if_not_ref(Val(reverse), vals, darg, 
ty)::Base.RefValue{ty} + end, + ) else BatchDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta @@ -654,7 +872,12 @@ end end end -@inline function iterate_unwrap_augfwd_mix(::Val{reverse}, vals, args, dargs0) where reverse +@inline function iterate_unwrap_augfwd_mix( + ::Val{reverse}, + vals, + args, + dargs0, +) where {reverse} dargs = dargs0[] ntuple(Val(length(args))) do i Base.@_inline_meta @@ -667,14 +890,23 @@ end Active(arg) elseif actreg == MixedState darg = Base.inferencebarrier(dargs[i]) - MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) + MixedDuplicated( + arg, + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}, + ) else Duplicated(arg, dargs[i]) end end end -@inline function iterate_unwrap_augfwd_batchmix(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width} +@inline function iterate_unwrap_augfwd_batchmix( + ::Val{reverse}, + vals, + ::Val{Width}, + args, + dargs, +) where {reverse,Width} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -685,11 +917,14 @@ end elseif actreg == ActiveState Active(arg) elseif actreg == MixedState - BatchMixedDuplicated(arg, ntuple(Val(Width)) do j - Base.@_inline_meta - darg = Base.inferencebarrier(dargs[j][][i]) - push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} - end) + BatchMixedDuplicated( + arg, + ntuple(Val(Width)) do j + Base.@_inline_meta + darg = Base.inferencebarrier(dargs[j][][i]) + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} + end, + ) else BatchDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta @@ -699,21 +934,21 @@ end end end -@inline function allFirst(::Val{Width}, res) where Width +@inline function allFirst(::Val{Width}, res) where {Width} ntuple(Val(Width)) do i Base.@_inline_meta res[1] end end -@inline function allSame(::Val{Width}, res) where Width +@inline function allSame(::Val{Width}, res) where {Width} ntuple(Val(Width)) do i Base.@_inline_meta res end end 
-@inline function allZero(::Val{Width}, res) where Width +@inline function allZero(::Val{Width}, res) where {Width} ntuple(Val(Width)) do i Base.@_inline_meta Ref(make_zero(res)) @@ -721,21 +956,31 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} +function fwddiff_with_return( + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{dupClosure0}, + ::Type{ReturnType}, + ::Type{FT}, + ::Type{tt′}, + f::FT, + df::DF, + args::Vararg{Annotation,Nargs}, +)::ReturnType where {RuntimeActivity,width,dupClosure0,ReturnType,FT,tt′,DF,Nargs} ReturnPrimal = Val(true) - ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) + ModifiedBetween = Val(Enzyme.falses_from_args(Nargs + 1)) dupClosure = dupClosure0 && !guaranteed_const(FT) FA = dupClosure ? Duplicated{FT} : Const{FT} - tt = Enzyme.vaEltypes(tt′) + tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ForwardMode) annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, width} + BatchDuplicated{rt,width} else Const{rt} end @@ -758,10 +1003,25 @@ function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width} Const(f) end opt_mi = Val(world) - res = thunk(opt_mi, FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity)(fa, args...) 
+ res = thunk( + opt_mi, + FA, + annotation, + tt′, + Val(API.DEM_ForwardMode), + Val(width), #=Mode=# + ModifiedBetween, + ReturnPrimal, + Val(false), + FFIABI, + Val(false), + runtimeActivity, + )( + fa, + args..., + ) #=erriffuncwritten=# return if annotation <: Const - ReturnType(allFirst(Val(width+1), res)) + ReturnType(allFirst(Val(width + 1), res)) else if width == 1 ReturnType((res[2], res[1])) @@ -773,38 +1033,66 @@ end function body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) wrappedexexpand = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds wrappedexexpand[i] = :($(wrapped[i])...) end return quote $(active_refs...) args = ($(wrappedexexpand...),) - tt′ = Enzyme.vaTypeof(args...) + tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - fwddiff_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + fwddiff_with_return( + runtimeActivity, + Val($Width), + Val(ActivityTup[1]), + ReturnType, + FT, + tt′, + f, + df, + args..., + )::ReturnType end end function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = + setup_macro_wraps(true, N, Width, nothing, true) #=iterate=# body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,ReturnType,F,DF,$(typeargs...)} $body end end end -@generated function 
runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) +@generated function runtime_iterate_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, _, _, active_refs = + setup_macro_wraps(true, N, Width, :allargs, true) #=iterate=# return body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) end -@generated function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs +@generated function primal_tuple(args::Vararg{Annotation,Nargs}) where {Nargs} expr = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs @inbounds expr[i] = :(args[$i].val) end return quote @@ -813,16 +1101,20 @@ end end end -@generated function shadow_tuple(::Type{Ann}, ::Val{1}, args::Vararg{Annotation, Nargs}) where {Ann, Nargs} +@generated function shadow_tuple( + ::Type{Ann}, + ::Val{1}, + args::Vararg{Annotation,Nargs}, +) where {Ann,Nargs} expr = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs @inbounds expr[i] = quote @assert !(args[$i] isa Active) if args[$i] isa Const args[$i].val elseif args[$i] isa MixedDuplicated args[$i].dval[] - else + else args[$i].dval end end @@ -837,18 +1129,22 @@ end end end -@generated function shadow_tuple(::Type{Ann}, ::Val{width}, args::Vararg{Annotation, Nargs}) where {Ann, width, Nargs} +@generated function shadow_tuple( + ::Type{Ann}, + ::Val{width}, + args::Vararg{Annotation,Nargs}, +) where {Ann,width,Nargs} wexpr = Vector{Expr}(undef, width) - for w in 
1:width + for w = 1:width expr = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs @inbounds expr[i] = quote @assert !(args[$i] isa Active) if args[$i] isa Const args[$i].val elseif args[$i] isa BatchMixedDuplicated args[$i].dval[$w][] - else + else args[$i].dval[$w] end end @@ -867,19 +1163,40 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} +function augfwd_with_return( + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{dupClosure0}, + ::Type{ReturnType}, + ::Val{ModifiedBetween0}, + ::Type{FT}, + ::Type{tt′}, + f::FT, + df::DF, + args::Vararg{Annotation,Nargs}, +)::ReturnType where { + RuntimeActivity, + width, + dupClosure0, + ReturnType, + ModifiedBetween0, + FT, + tt′, + DF, + Nargs, +} ReturnPrimal = Val(true) ModifiedBetween = Val(ModifiedBetween0) - tt = Enzyme.vaEltypes(tt′) + tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, width} + BatchDuplicated{rt,width} elseif annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt, width} + BatchMixedDuplicated{rt,width} elseif annotation0 <: Active Active{rt} else @@ -912,27 +1229,46 @@ function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, end world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, FA, - annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, 
#=erriffuncwritten=#Val(false), runtimeActivity) + forward, adjoint = thunk( + opt_mi, + FA, + annotation, + tt′, + Val(API.DEM_ReverseModePrimal), + Val(width), + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# forward(fa, args...) else - nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) + nothing, + primal_tuple(args...), + annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) end resT = typeof(origRet) if annotation <: Const shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType((allSame(Val(width+1), origRet)..., tape)) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) + return ReturnType((allSame(Val(width + 1), origRet)..., tape)) elseif annotation <: Active shadow_return = if width == 1 Ref(make_zero(origRet)) else allZero(Val(width), origRet) end - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) if width == 1 return ReturnType((origRet, shadow_return, tape)) else @@ -943,21 +1279,33 @@ function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, if width == 1 if annotation <: MixedDuplicated shadow_return = initShadow - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow, tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return 
ReturnType((origRet, initShadow, tape)) end else if annotation <: BatchMixedDuplicated shadow_return = initShadow - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) end end @@ -965,65 +1313,129 @@ end function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) wrappedexexpand = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds wrappedexexpand[i] = :($(wrapped[i])...) end - results = Vector{Expr}(undef, Width+1) - for i in 1:(Width+1) + results = Vector{Expr}(undef, Width + 1) + for i = 1:(Width+1) results[i] = :(tmpvals[$i]) end return quote refs = Base.RefValue[] $(active_refs...) args = ($(wrappedexexpand...),) - tt′ = Enzyme.vaTypeof(args...) + tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - tmpvals = augfwd_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType - ReturnType(($(results...), (tmpvals[$(Width+2)], refs))) + tmpvals = augfwd_with_return( + runtimeActivity, + Val($Width), + Val(ActivityTup[1]), + ReturnType, + Val(concat($(modbetween...))), + FT, + tt′, + f, + df, + args..., + )::ReturnType + ReturnType(($(results...), (tmpvals[$(Width + 2)], refs))) end end function func_runtime_iterate_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) - body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) + _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = + setup_macro_wraps(false, N, Width, nothing, true) #=iterate=# + body = + body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) quote - function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,MB,ReturnType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, MB, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) +@generated function runtime_iterate_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,MB,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, _, modbetween, active_refs = + setup_macro_wraps(false, N, Width, :allargs, true) #=iterate=# + return body_runtime_iterate_augfwd( + N, + Width, + modbetween, + wrapped, + primtypes, + active_refs, + ) end function add_into_vec!(val::Base.RefValue, expr, vec, idx_in_vec) - val[] = recursive_add(val[], expr, identity, guaranteed_nonactive) - nothing + val[] = recursive_add(val[], expr, identity, guaranteed_nonactive) + nothing end -function add_into_vec!(val::T, expr, vec, idx_in_vec) where T +function add_into_vec!(val::T, expr, vec, idx_in_vec) where {T} if ismutable(vec) @inbounds vec[idx_in_vec] = recursive_add(val, expr, identity, guaranteed_nonactive) else - error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") + error( + "Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec", + ) end nothing end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -@generated function rev_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{ttp}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {RuntimeActivity, width, 
dupClosure0, ModifiedBetween0, lengths, FT, ttp, DF, Nargs} +@generated function rev_with_return( + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{dupClosure0}, + ::Val{ModifiedBetween0}, + ::Val{lengths}, + ::Type{FT}, + ::Type{ttp}, + f::FT, + df::DF, + tape, + shadowargs, + args::Vararg{Annotation,Nargs}, +)::Nothing where { + RuntimeActivity, + width, + dupClosure0, + ModifiedBetween0, + lengths, + FT, + ttp, + DF, + Nargs, +} nontupexprs = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs mid = if width == 1 :(tape.shadow_return[][$i]) else mexprs = Vector{Expr}(undef, width) - for w in 1:width + for w = 1:width @inbounds mexprs[w] = :(tape.shadow_return[$w][][$i]) end quote @@ -1032,7 +1444,9 @@ end end @inbounds nontupexprs[i] = quote - if args[$i] isa Active || args[$i] isa MixedDuplicated || args[$i] isa BatchMixedDuplicated + if args[$i] isa Active || + args[$i] isa MixedDuplicated || + args[$i] isa BatchMixedDuplicated $mid else nothing @@ -1041,10 +1455,12 @@ end end endexprs = Matrix{Expr}(undef, Nargs, width) - for i in 1:Nargs - for w in 1:width + for i = 1:Nargs + for w = 1:width @inbounds endexprs[i, w] = quote - if args[$i] isa Active || args[$i] isa MixedDuplicated || args[$i] isa BatchMixedDuplicated + if args[$i] isa Active || + args[$i] isa MixedDuplicated || + args[$i] isa BatchMixedDuplicated expr = if args[$i] isa Active || f == Base.tuple if $width == 1 tup[$i] @@ -1061,7 +1477,7 @@ end idx_of_vec, idx_in_vec = $(lengths[i]) vec = @inbounds shadowargs[idx_of_vec][$w] if vec isa Base.RefValue - vecld = vec[] + vecld = vec[] T = Core.Typeof(vecld) vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), expr) else @@ -1079,9 +1495,9 @@ end annotation = if width != 1 quote if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, $width} + BatchDuplicated{rt,$width} elseif annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt, $width} + BatchMixedDuplicated{rt,$width} elseif 
annotation0 <: Active Active{rt} else @@ -1106,7 +1522,7 @@ end :(adjoint(fa, args..., tape.shadow_return[], tape.internal_tape)[1]) else margs = Vector{Expr}(undef, width) - for w in 1:width + for w = 1:width @inbounds margs[w] = :(tape.shadow_return[$w][]) end :(adjoint(fa, args..., ($(margs...),), tape.internal_tape)[1]) @@ -1121,7 +1537,7 @@ end dupClosure = $dupClosure0 && !guaranteed_const($FT) FA = dupClosure ? Duplicated{$FT} : Const{$FT} - tt = $tt + tt = $tt rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) @@ -1135,10 +1551,21 @@ end Const(f) end opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, FA, - annotation, $ttp, Val(API.DEM_ReverseModePrimal), Val($width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - + forward, adjoint = thunk( + opt_mi, + FA, + annotation, + $ttp, + Val(API.DEM_ReverseModePrimal), + Val($width), + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# + tup = if tape.shadow_return !== nothing $shadadj else @@ -1154,9 +1581,9 @@ end end end -@generated function ntuple_pair(::Val{Len}, ::Val{i}) where {Len, i} +@generated function ntuple_pair(::Val{Len}, ::Val{i}) where {Len,i} mexprs = Vector{Expr}(undef, Len) - for j in 1:Len + for j = 1:Len @inbounds mexprs[j] = quote ($i, $j) end @@ -1167,24 +1594,32 @@ end end end -function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs, active_refs) +function body_runtime_iterate_rev( + N, + Width, + modbetween, + wrapped, + primargs, + shadowargs, + active_refs, +) shadow_ret = nothing if Width == 1 shadowret = :(tape.shadow_return[]) else shadowret = Expr[] - for w in 1:Width + for w = 1:Width push!(shadowret, :(tape.shadow_return[$w][])) end shadowret = :(($(shadowret...),)) end wrappedexexpand = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N wrappedexexpand[i] = 
:($(wrapped[i])...) end lengths = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N lengths[i] = quote ntuple_pair(Val(length($(primargs[i]))), Val($i)) end @@ -1198,28 +1633,84 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado (tape0, refs) = tape $(active_refs...) args = ($(wrappedexexpand...),) - tt′ = Enzyme.vaTypeof(args...) + tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - rev_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) + rev_with_return( + runtimeActivity, + Val($Width), + Val(ActivityTup[1]), + Val(concat($(modbetween...))), + Val(concat($(lengths...))), + FT, + tt′, + f, + df, + tape0, + ($(shadowsplat...),), + args..., + ) return nothing end end function func_runtime_iterate_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true; reverse=true) - body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) + primargs, + _, + primtypes, + allargs, + typeargs, + wrapped, + batchshadowargs, + modbetween, + active_refs = setup_macro_wraps(false, N, Width, nothing, true; reverse = true) #=iterate=# + body = body_runtime_iterate_rev( + N, + Width, + modbetween, + wrapped, + primargs, + batchshadowargs, + active_refs, + ) quote - function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} + function runtime_iterate_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + $(allargs...), + ) where 
{ActivityTup,RuntimeActivity,MB,TapeType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, MB, Width, TapeType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true; reverse=true) - return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) +@generated function runtime_iterate_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,MB,Width,TapeType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = + setup_macro_wraps(false, N, Width, :allargs, true; reverse = true) #=iterate=# + return body_runtime_iterate_rev( + N, + Width, + modbetween, + wrapped, + primargs, + batchshadowargs, + active_refs, + ) end # Create specializations @@ -1232,7 +1723,21 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true, runtime_activity=true) +function generic_setup( + orig, + func, + ReturnType, + gutils, + start, + B::LLVM.IRBuilder, + lookup; + sret = nothing, + tape = nothing, + firstconst = false, + endcast = true, + firstconst_after_tape = true, + runtime_activity = true, +) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1271,7 +1776,7 @@ function 
generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, inverted = nothing active = !is_constant_value(gutils, op) - + if !active push!(ActivityList, unsafe_to_llvm(B, false)) else @@ -1285,19 +1790,27 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else extract_value!(B, inverted, 0) end - push!(ActivityList, select!(B, icmp!(B, LLVM.API.LLVMIntNE, val, inv_0), unsafe_to_llvm(B, true), unsafe_to_llvm(B, false))) + push!( + ActivityList, + select!( + B, + icmp!(B, LLVM.API.LLVMIntNE, val, inv_0), + unsafe_to_llvm(B, true), + unsafe_to_llvm(B, false), + ), + ) else push!(ActivityList, unsafe_to_llvm(B, true)) end end - for w in 1:width + for w = 1:width ev = fill_val if inverted !== nothing if width == 1 ev = inverted else - ev = extract_value!(B, inverted, w-1) + ev = extract_value!(B, inverted, w - 1) end end @@ -1317,7 +1830,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else pushfirst!(vals, unsafe_to_llvm(B, Val(ReturnType))) end - + if firstconst && firstconst_after_tape val = new_from_original(gutils, operands(orig)[start]) if lookup @@ -1333,7 +1846,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, ModifiedBetween = Bool[] - for idx in 1:(length(ops)+firstconst) + for idx = 1:(length(ops)+firstconst) push!(ModifiedBetween, uncacheable[(start-1)+idx] != 0) end pushfirst!(vals, unsafe_to_llvm(B, Val((ModifiedBetween...,)))) @@ -1344,7 +1857,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, pushfirst!(vals, unsafe_to_llvm(B, Val(get_runtime_activity(gutils)))) end etup0 = emit_tuple!(B, ActivityList) - etup = emit_apply_type!(B, Base.Val, [etup0]) + etup = emit_apply_type!(B, Base.Val, [etup0]) if isa(etup, LLVM.Instruction) @assert length(collect(LLVM.uses(etup0))) == 1 end @@ -1355,7 +1868,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, cal = 
emit_apply_generic!(B, vals) debug_from_orig!(gutils, cal, orig) - + if tape === nothing && endcast llty = convert(LLVMType, ReturnType) cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) @@ -1366,42 +1879,69 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + sret = generic_setup( + orig, + runtime_generic_fwd, + AnyArray(1 + Int(width)), + gutils, + offset, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, 
gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1419,45 +1959,76 @@ end end function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(shadowR)) : nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) - - if unsafe_load(shadowR) != C_NULL + sret = generic_setup( + orig, + runtime_generic_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) + + if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + 
LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)]), + ) unsafe_store!(tapeR, tape.ref) if normalR != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1479,15 +2050,22 @@ end function common_generic_rev(offset, B, orig, gutils, tape)::Cvoid needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset, B, true; tape) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, offset, B, true; tape) #=start=# return nothing end @@ -1504,9 +2082,16 @@ end function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end mod = 
LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1515,27 +2100,45 @@ function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) width = get_width(gutils) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) + sret = generic_setup( + orig, + runtime_generic_fwd, + AnyArray(1 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1549,9 +2152,16 @@ end function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || 
needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end @@ -1559,31 +2169,53 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) width = get_width(gutils) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) # sret = generic_setup(orig, runtime_apply_latest_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, ctx, B, false) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) + sret = generic_setup( + orig, + runtime_generic_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)]), + ) unsafe_store!(tapeR, tape.ref) if 
unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1596,14 +2228,21 @@ end function common_apply_latest_rev(offset, B, orig, gutils, tape)::Cvoid needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, offset + 1, B, true; tape) #=start=# end return nothing @@ -1637,9 +2276,16 @@ end function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end @@ -1648,9 +2294,14 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) 
width = get_width(gutils) - if v && v2 && isiter == Base.iterate && istup == Base.tuple && length(operands(orig)) >= offset+4 + if v && + v2 && + isiter == Base.iterate && + istup == Base.tuple && + length(operands(orig)) >= offset + 4 origops = collect(operands(orig)[1:end-1]) - shadowins = [ invert_pointer(gutils, origops[i], B) for i in (offset+3):length(origops) ] + shadowins = + [invert_pointer(gutils, origops[i], B) for i = (offset+3):length(origops)] shadowres = if width == 1 newops = LLVM.Value[] newvals = API.CValueType[] @@ -1664,18 +2315,25 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) push!(newvals, API.VT_Primal) end end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width + for j = 1:width newops = LLVM.Value[] newvals = API.CValueType[] for (i, v) in enumerate(origops) if i >= offset + 3 - shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) + shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j - 1) push!(newops, shadowin2) push!(newvals, API.VT_Shadow) else @@ -1683,9 +2341,16 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) push!(newvals, API.VT_Primal) end end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - shadow = insert_value!(B, shadow, cal, j-1) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -1698,26 +2363,48 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = 
LLVM.PointerType(T_jlvalue, Tracked) - sret = generic_setup(orig, runtime_iterate_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+2, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + sret = generic_setup( + orig, + runtime_iterate_fwd, + AnyArray(1 + Int(width)), + gutils, + offset + 2, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(1)], + ) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1727,7 +2414,12 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) return false end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate "*string((v, v2, isiter, istup, length(operands(orig)), offset+4))) + emit_error( + B, + orig, + "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate " * + string((v, v2, isiter, istup, length(operands(orig)), offset + 4)), + ) return false end @@ -1735,9 +2427,16 @@ end function 
common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end @@ -1750,30 +2449,61 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - sret = generic_setup(orig, runtime_iterate_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+2, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + sret = generic_setup( + orig, + runtime_iterate_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset + 2, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) - if unsafe_load(shadowR) != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(1)], + ) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, 
LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)], + ), + ) unsafe_store!(tapeR, tape.ref) if normalR != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1784,24 +2514,39 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, return false end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate "*string((v, v2, isiter, istup, length(operands(orig)), offset+4))) + emit_error( + B, + orig, + "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate " * + string((v, v2, isiter, istup, length(operands(orig)), offset + 4)), + ) - unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) + unsafe_store!( + shadowR, + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref, + ) return false end function common_apply_iterate_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_iterate_rev, Nothing, gutils, 
#=start=#offset+2, B, true; tape) + generic_setup(orig, runtime_iterate_rev, Nothing, gutils, offset + 2, B, true; tape) #=start=# return nothing end @@ -1821,37 +2566,62 @@ end function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + sret = generic_setup( + orig, + runtime_generic_fwd, + AnyArray(1 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if 
unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1865,44 +2635,75 @@ end function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing - + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(shadowR)) : nothing + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) conv = LLVM.callconv(orig) width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + sret = generic_setup( + orig, + runtime_generic_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)]), + ) unsafe_store!(tapeR, tape.ref) if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1916,14 +2717,21 @@ end function common_invoke_rev(offset, B, orig, 
gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end - + width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, offset + 1, B, true; tape) #=start=# return nothing end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 965114447c..a41912cf82 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1,14 +1,30 @@ macro register_aug(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) expr2 = :(@inline $expr) res = quote - function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::UInt8 - return UInt8($name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR, tapeR)::Bool) + function $cname( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + normalR::Ptr{LLVM.API.LLVMValueRef}, + shadowR::Ptr{LLVM.API.LLVMValueRef}, + tapeR::Ptr{LLVM.API.LLVMValueRef}, + )::UInt8 + return UInt8( + $name( + LLVM.IRBuilder(B), + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + normalR, + shadowR, + tapeR, + )::Bool, + ) end end return Expr(:block, esc(expr2), esc(res)) @@ -17,14 +33,24 
@@ end macro register_rev(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) expr2 = :(@inline $expr) res = quote - function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid - $name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), tape == C_NULL ? nothing : LLVM.Value(tape)) + function $cname( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + tape::LLVM.API.LLVMValueRef, + )::Cvoid + $name( + LLVM.IRBuilder(B), + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + tape == C_NULL ? nothing : LLVM.Value(tape), + ) return end end @@ -34,13 +60,27 @@ end macro register_fwd(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) expr2 = :(@inline $expr) res = quote - function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::UInt8 - return UInt8($name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR)::Bool) + function $cname( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + normalR::Ptr{LLVM.API.LLVMValueRef}, + shadowR::Ptr{LLVM.API.LLVMValueRef}, + )::UInt8 + return UInt8( + $name( + LLVM.IRBuilder(B), + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + normalR, + shadowR, + )::Bool, + ) end end return Expr(:block, esc(expr2), esc(res)) @@ -49,13 +89,26 @@ end macro register_diffuse(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) expr2 
= :(@inline $expr) res = quote - function $cname(OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, val::LLVM.API.LLVMValueRef, shadow::UInt8, mode::API.CDerivativeMode, useDefault::Ptr{UInt8})::UInt8 - res = $name(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} + function $cname( + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + val::LLVM.API.LLVMValueRef, + shadow::UInt8, + mode::API.CDerivativeMode, + useDefault::Ptr{UInt8}, + )::UInt8 + res = $name( + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + LLVM.Value(val), + shadow != 0, + mode, + )::Tuple{Bool,Bool} unsafe_store!(useDefault, UInt8(res[2])) return UInt8(res[1]) end @@ -75,7 +128,15 @@ include("parallelrules.jl") if in(name, ("ijl_apply_generic", "jl_apply_generic")) return common_generic_fwd(2, B, orig, gutils, normalR, shadowR) end - if in(name, ("ijl_f__apply_latest", "ijl_f__call_latest", "jl_f__apply_latest", "jl_f__call_latest")) + if in( + name, + ( + "ijl_f__apply_latest", + "ijl_f__call_latest", + "jl_f__apply_latest", + "jl_f__call_latest", + ), + ) return common_apply_latest_fwd(2, B, orig, gutils, normalR, shadowR) end if in(name, ("ijl_new_structv", "jl_new_structv")) @@ -99,17 +160,27 @@ include("parallelrules.jl") if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_fwd(2, B, orig, gutils, normalR, shadowR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end - err = emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in forward for "*string(orig)) - + err = emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in forward for " * string(orig), + ) + newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = 
(unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -131,7 +202,15 @@ end if in(name, ("ijl_apply_generic", "jl_apply_generic")) return common_generic_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if in(name, ("ijl_f__apply_latest", "ijl_f__call_latest", "jl_f__apply_latest", "jl_f__call_latest")) + if in( + name, + ( + "ijl_f__apply_latest", + "ijl_f__call_latest", + "jl_f__apply_latest", + "jl_f__call_latest", + ), + ) return common_apply_latest_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end if in(name, ("ijl_new_structv", "jl_new_structv")) @@ -155,16 +234,27 @@ end if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end - err = emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in aug_forward for "*string(orig)) + err = emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in aug_forward for " * + string(orig), + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -187,7 +277,15 @@ end common_generic_rev(2, B, orig, gutils, tape) return nothing end - if in(name, ("ijl_f__apply_latest", "ijl_f__call_latest", "jl_f__apply_latest", "jl_f__call_latest")) + if in( + name, + ( + "ijl_f__apply_latest", + "ijl_f__call_latest", + "jl_f__apply_latest", + "jl_f__call_latest", + ), + ) common_apply_latest_rev(2, B, orig, gutils, tape) return nothing end @@ -219,12 +317,21 @@ end common_finalizer_rev(2, B, orig, gutils, tape) return nothing end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return nothing end end - emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in reverse for "*string(orig)) + emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in reverse for " * string(orig), + ) return nothing end @@ -236,7 +343,12 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_fwd(2, B, orig, gutils, normalR, shadowR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end @@ -253,7 +365,12 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end @@ -271,7 +388,12 @@ end common_invoke_rev(2, B, orig, gutils, tape) return nothing end - if 
any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return nothing end end @@ -295,8 +417,15 @@ end real_ops = collect(operands(orig))[1:end-1] ops = [lookup_value(gutils, new_from_original(gutils, o), B) for o in real_ops] - - c = call_samefunc_with_inverted_bundles!(B, gutils, orig, ops, [API.VT_Primal for _ in ops], #=lookup=#false) + + c = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + ops, + [API.VT_Primal for _ in ops], + false, + ) #=lookup=# callconv!(c, callconv(orig)) return nothing @@ -319,64 +448,117 @@ end algn = 0 if width == 1 - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, [shadowin], [API.VT_Shadow], #=lookup=#false) + shadowres = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + [shadowin], + [API.VT_Shadow], + false, + ) #=lookup=# # TODO zero based off runtime types, rather than presume floatlike? 
if is_constant_value(gutils, origops[1]) elSize = get_array_elsz(B, shadowin) - elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) len = get_array_len(B, shadowin) length = LLVM.mul!(B, len, elSize) - bt = GPUCompiler.backtrace(orig) - btstr = sprint() do io - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) - end + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" - LLVM.memset!(B, get_array_data(B, shadowres), LLVM.ConstantInt(i8, 0, false), length, algn) + LLVM.memset!( + B, + get_array_data(B, shadowres), + LLVM.ConstantInt(i8, 0, false), + length, + algn, + ) end if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) - shadowres = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, shadowin, new_from_original(gutils, origops[1])), shadowres, prev) + shadowres = LLVM.select!( + B, + LLVM.icmp!( + B, + LLVM.API.LLVMIntNE, + shadowin, + new_from_original(gutils, origops[1]), + ), + shadowres, + prev, + ) API.moveBefore(prev, shadowres, B) end else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - ev = extract_value!(B, shadowin, idx-1) - callv = call_samefunc_with_inverted_bundles!(B, gutils, orig, [ev], [API.VT_Shadow], #=lookup=#false) + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + ev = extract_value!(B, shadowin, idx - 1) + callv = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + [ev], + [API.VT_Shadow], + false, + ) #=lookup=# if is_constant_value(gutils, origops[1]) elSize = get_array_elsz(B, ev) - elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) len = 
get_array_len(B, ev) length = LLVM.mul!(B, len, elSize) - bt = GPUCompiler.backtrace(orig) - btstr = sprint() do io - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) - end + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" - LLVM.memset!(B, get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn) + LLVM.memset!( + B, + get_array_data(B, callv), + LLVM.ConstantInt(i8, 0, false), + length, + algn, + ) end if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) - callv = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, ev, new_from_original(gutils, origops[1])), callv, prev) + callv = LLVM.select!( + B, + LLVM.icmp!( + B, + LLVM.API.LLVMIntNE, + ev, + new_from_original(gutils, origops[1]), + ), + callv, + prev, + ) if idx == 1 API.moveBefore(prev, callv, B) end end - shadowres = insert_value!(B, shadowres, callv, idx-1) + shadowres = insert_value!(B, shadowres, callv, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) - return false + return false end function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 if !needsShadow @@ -390,20 +572,20 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) # size_t len = jl_array_len(ary); # size_t elsz = ary->elsize; # memcpy(new_ary->data, ary->data, len * elsz); - # JL_EXTENSION typedef struct { - # JL_DATA_TYPE - # void *data; - # #ifdef STORE_ARRAY_LEN - # size_t length; - # #endif - # jl_array_flags_t 
flags; - # uint16_t elsize; // element size including alignment (dim 1 memory stride) - - tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, orig)) + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *data; + # #ifdef STORE_ARRAY_LEN + # size_t length; + # #endif + # jl_array_flags_t flags; + # uint16_t elsize; // element size including alignment (dim 1 memory stride) + + tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, orig)) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - dl = string(LLVM.datalayout(mod)) - API.EnzymeTypeTreeLookupEq(tt, 1, dl) - data0!(tt) + dl = string(LLVM.datalayout(mod)) + API.EnzymeTypeTreeLookupEq(tt, 1, dl) + data0!(tt) ct = API.EnzymeTypeTreeInner0(tt) if ct == API.DT_Unknown @@ -411,7 +593,11 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) # ip = API.EnzymeTypeAnalyzerToString(analyzer) # sval = Base.unsafe_string(ip) # API.EnzymeStringFree(ip) - emit_error(B, orig, "Enzyme: Unknown concrete type in arraycopy_common. tt: " * string(tt)) + emit_error( + B, + orig, + "Enzyme: Unknown concrete type in arraycopy_common. 
tt: " * string(tt), + ) return nothing end @@ -431,7 +617,14 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) B0 = B elseif typeof(actualOp) <: LLVM.Argument B0 = LLVM.IRBuilder() - position!(B0, first(instructions(new_from_original(gutils, LLVM.entry(LLVM.parent(LLVM.parent(orig))))))) + position!( + B0, + first( + instructions( + new_from_original(gutils, LLVM.entry(LLVM.parent(LLVM.parent(orig)))), + ), + ), + ) else B0 = LLVM.IRBuilder() nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(actualOp)) @@ -442,7 +635,7 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) end elSize = get_array_elsz(B0, actualOp) - elSize = LLVM.zext!(B0, elSize, LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!(B0, elSize, LLVM.IntType(8 * sizeof(Csize_t))) len = get_array_len(B0, actualOp) @@ -478,30 +671,64 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) if width == 1 - shadowsrc = get_array_data(B, shadowsrc) - shadowdst = get_array_data(B, shadowdst) - - if fwd && secretty != nothing - LLVM.memset!(B, shadowdst, LLVM.ConstantInt(i8, 0, false), length, algn) - end - - API.sub_transfer(gutils, fwd ? API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, secretty, intrinsic, #=dstAlign=#1, #=srcAlign=#1, #=offset=#0, false, shadowdst, false, shadowsrc, length, isVolatile, orig, allowForward, #=shadowsLookedUp=#!fwd) + shadowsrc = get_array_data(B, shadowsrc) + shadowdst = get_array_data(B, shadowdst) + + if fwd && secretty != nothing + LLVM.memset!(B, shadowdst, LLVM.ConstantInt(i8, 0, false), length, algn) + end + + API.sub_transfer( + gutils, + fwd ? 
API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, + secretty, + intrinsic, + 1, + 1, + 0, + false, + shadowdst, + false, + shadowsrc, + length, + isVolatile, + orig, + allowForward, + !fwd, + ) #=shadowsLookedUp=# else - for i in 1:width + for i = 1:width - evsrc = extract_value!(B, shadowsrc, i-1) - evdst = extract_value!(B, shadowdst, i-1) + evsrc = extract_value!(B, shadowsrc, i - 1) + evdst = extract_value!(B, shadowdst, i - 1) - shadowsrc0 = get_array_data(B, evsrc) - shadowdst0 = get_array_data(B, evdst) + shadowsrc0 = get_array_data(B, evsrc) + shadowdst0 = get_array_data(B, evdst) - if fwd && secretty != nothing - LLVM.memset!(B, shadowdst0, LLVM.ConstantInt(i8, 0, false), length, algn) - end + if fwd && secretty != nothing + LLVM.memset!(B, shadowdst0, LLVM.ConstantInt(i8, 0, false), length, algn) + end - API.sub_transfer(gutils, fwd ? API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, secretty, intrinsic, #=dstAlign=#1, #=srcAlign=#1, #=offset=#0, false, shadowdst0, false, shadowsrc0, length, isVolatile, orig, allowForward, #=shadowsLookedUp=#!fwd) - end + API.sub_transfer( + gutils, + fwd ? 
API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, + secretty, + intrinsic, + 1, + 1, + 0, + false, + shadowdst0, + false, + shadowsrc0, + length, + isVolatile, + orig, + allowForward, + !fwd, + ) #=shadowsLookedUp=# + end end @@ -517,18 +744,18 @@ end origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) - shadowres = LLVM.Value(unsafe_load(shadowR)) + shadowres = LLVM.Value(unsafe_load(shadowR)) - arraycopy_common(#=fwd=#true, B, orig, origops[1], gutils, shadowres) + arraycopy_common(true, B, orig, origops[1], gutils, shadowres) #=fwd=# end - return false + return false end @register_rev function arraycopy_rev(B, orig, gutils, tape) origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) - arraycopy_common(#=fwd=#false, B, orig, origops[1], gutils, nothing) + arraycopy_common(false, B, orig, origops[1], gutils, nothing) #=fwd=# end return nothing @@ -548,25 +775,41 @@ end shadowin = invert_pointer(gutils, origops[2], B) if width == 1 args = LLVM.Value[ - new_from_original(gutils, origops[1]) - shadowin - new_from_original(gutils, origops[3]) - ] - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Primal, API.VT_Shadow, API.VT_Primal], #=lookup=#false) + new_from_original(gutils, origops[1]) + shadowin + new_from_original(gutils, origops[3]) + ] + shadowres = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Primal, API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[new_from_original(gutils, origops[1]) - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[3]) - ] - tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Primal, API.VT_Shadow, API.VT_Primal], #=lookup=#false) - shadowres = insert_value!(B, shadowres, tmp, 
idx-1) + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[ + new_from_original(gutils, origops[1]) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[3]) + ] + tmp = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Primal, API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) - return false + return false end @register_aug function arrayreshape_augfwd(B, orig, gutils, normalR, shadowR, tapeR) @@ -580,12 +823,18 @@ end @register_fwd function gcloaded_fwd(B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) return true end - + origops = LLVM.operands(orig) if is_constant_value(gutils, origops[1]) emit_error(B, orig, "Enzyme: gcloaded has active return, but inactive input(1)") @@ -600,21 +849,36 @@ end shadowin2 = invert_pointer(gutils, origops[2], B) if width == 1 args = LLVM.Value[shadowin1, shadowin2] - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) + shadowres = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Shadow], + false, + ) #=lookup=# else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, 
shadowin1, idx-1) - extract_value!(B, shadowin2, idx-1) - ] - tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + extract_value!(B, shadowin1, idx - 1) + extract_value!(B, shadowin2, idx - 1) + ] + tmp = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Shadow], + false, + ) #=lookup=# + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) - return false + return false end @register_aug function gcloaded_augfwd(B, orig, gutils, normalR, shadowR, tapeR) @@ -628,10 +892,16 @@ end @register_fwd function boxfloat_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) - + needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true @@ -643,14 +913,13 @@ end shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), shadowsin) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[ - extract_value!(B, s, idx-1) for s in shadowsin - ] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[extract_value!(B, s, idx - 1) for s in shadowsin] tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -660,10 +929,16 @@ 
end @register_aug function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) origops = collect(operands(orig)) width = get_width(gutils) - + needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true @@ -679,11 +954,11 @@ end shadowres = obj else shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) - for idx in 1:width + for idx = 1:width obj = emit_allocobj!(B, Base.RefValue{TT}) o2 = bitcast!(B, obj, LLVM.PointerType(flt, addrspace(value_type(obj)))) store!(B, ConstantFP(flt, 0.0), o2) - shadowres = insert_value!(B, shadowres, obj, idx-1) + shadowres = insert_value!(B, shadowres, obj, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -691,10 +966,16 @@ end end @register_rev function boxfloat_rev(B, orig, gutils, tape) - + needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) if is_constant_value(gutils, orig) || needsShadowP[] == 0 return nothing @@ -713,12 +994,12 @@ end end else shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) - for idx in 1:width - ipc = extract_value!(B, ip, idx-1) + for idx = 1:width + ipc = extract_value!(B, ip, idx - 1) ipc = bitcast!(B, ipc, LLVM.PointerType(flt, addrspace(value_type(orig)))) ld = load!(B, flt, ipc) store!(B, ConstantFP(flt, 0.0), ipc) - shadowres = insert_value!(B, shadowres, ld, idx-1) + shadowres = insert_value!(B, shadowres, ld, idx - 1) end if !is_constant_value(gutils, origops[1]) 
API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], shadowret, B, flt) @@ -734,7 +1015,8 @@ end emit_error(B, orig, "Enzyme: Not yet implemented forward for jl_eqtable_get") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -742,11 +1024,13 @@ end return false end -function error_if_active(::Type{T}) where T +function error_if_active(::Type{T}) where {T} seen = () - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) + areg = active_reg_inner(T, seen, nothing, Val(true)) #=justActive=# if areg == ActiveState - throw(AssertionError("Found unhandled active variable in tuple splat, jl_eqtable $T")) + throw( + AssertionError("Found unhandled active variable in tuple splat, jl_eqtable $T"), + ) end nothing end @@ -755,12 +1039,18 @@ end if is_constant_value(gutils, orig) return true end - + mode = get_mode(gutils) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + mode, + ) if needsShadowP[] == 0 return false end @@ -772,14 +1062,49 @@ end origh, origkey, origdflt = operands(orig)[1:end-1] if is_constant_value(gutils, origh) - emit_error(B, orig, "Enzyme: Not yet implemented constant table in jl_eqtable_get "*string(origh)*" "*string(orig)*" result: "*string(absint(orig))*" "*string(abs_typeof(orig, true))*" dict: "*string(absint(origh))*" "*string(abs_typeof(origh, true))*" key "*string(absint(origkey))*" "*string(abs_typeof(origkey, true))*" dflt "*string(absint(origdflt))*" "*string(abs_typeof(origdflt, true))) + emit_error( + B, + orig, + "Enzyme: Not yet implemented constant table in jl_eqtable_get " * + 
string(origh) * + " " * + string(orig) * + " result: " * + string(absint(orig)) * + " " * + string(abs_typeof(orig, true)) * + " dict: " * + string(absint(origh)) * + " " * + string(abs_typeof(origh, true)) * + " key " * + string(absint(origkey)) * + " " * + string(abs_typeof(origkey, true)) * + " dflt " * + string(absint(origdflt)) * + " " * + string(abs_typeof(origdflt, true)), + ) end - + shadowh = invert_pointer(gutils, origh, B) shadowdflt = if is_constant_value(gutils, origdflt) - shadowdflt2 = julia_error(Base.unsafe_convert(Cstring, "Mixed activity for default of jl_eqtable_get "*string(orig)*" "*string(origdflt)), - orig.ref, API.ET_MixedActivityError, gutils.ref, origdflt.ref, B.ref) + shadowdflt2 = julia_error( + Base.unsafe_convert( + Cstring, + "Mixed activity for default of jl_eqtable_get " * + string(orig) * + " " * + string(origdflt), + ), + orig.ref, + API.ET_MixedActivityError, + gutils.ref, + origdflt.ref, + B.ref, + ) if shadowdflt2 != C_NULL LLVM.Value(shadowdflt2) else @@ -789,8 +1114,8 @@ end else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop))) shadowm = LLVM.UndefValue(ST) - for j in 1:width - shadowm = insert_value!(B, shadowm, nop, j-1) + for j = 1:width + shadowm = insert_value!(B, shadowm, nop, j - 1) end shadowm end @@ -798,24 +1123,41 @@ end else invert_pointer(gutils, origdflt, B) end - + newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow] - + shadowres = if width == 1 newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowdflt] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)]) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)], + ) cal else ST = 
LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), extract_value!(B, shadowdflt, j-1)] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + for j = 1:width + newops = LLVM.Value[ + extract_value!(B, shadowh, j - 1), + new_from_original(gutils, origkey), + extract_value!(B, shadowdflt, j - 1), + ] + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)]) - shadow = insert_value!(B, shadow, cal, j-1) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)], + ) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -834,7 +1176,8 @@ end end emit_error(B, orig, "Enzyme: Not yet implemented forward for jl_eqtable_put") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -857,8 +1200,20 @@ end shadowval = invert_pointer(gutils, origval, B) shadowval = if is_constant_value(gutils, origval) - shadowdflt2 = julia_error(Base.unsafe_convert(Cstring, "Mixed activity for val of jl_eqtable_put "*string(orig)*" "*string(origval)), - orig.ref, API.ET_MixedActivityError, gutils.ref, origval.ref, B.ref) + shadowdflt2 = julia_error( + Base.unsafe_convert( + Cstring, + "Mixed activity for val of jl_eqtable_put " * + string(orig) * + " " * + string(origval), + ), + orig.ref, + API.ET_MixedActivityError, + gutils.ref, + origval.ref, + B.ref, + ) if shadowdflt2 != C_NULL LLVM.Value(shadowdflt2) else @@ -868,8 +1223,8 @@ end else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop))) shadowm = LLVM.UndefValue(ST) - for j in 1:width - shadowm = insert_value!(B, shadowm, nop, j-1) + for j = 1:width + shadowm = insert_value!(B, shadowm, nop, j - 1) end shadowm end @@ -881,23 +1236,46 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow, API.VT_None] - + shadowres = if width == 1 - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, shadowval)]) - newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowval, LLVM.null(value_type(originserted))] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, shadowval)], + ) + newops = LLVM.Value[ + shadowh, + new_from_original(gutils, origkey), + shadowval, + LLVM.null(value_type(originserted)), + ] + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = 
LLVM.UndefValue(ST) - for j in 1:width - sval2 = extract_value!(B, shadowval, j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, sval2)]) - newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), sval2, LLVM.null(value_type(originserted))] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + for j = 1:width + sval2 = extract_value!(B, shadowval, j - 1) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, sval2)], + ) + newops = LLVM.Value[ + extract_value!(B, shadowh, j - 1), + new_from_original(gutils, origkey), + sval2, + LLVM.null(value_type(originserted)), + ] + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - shadow = insert_value!(B, shadow, cal, j-1) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -917,7 +1295,8 @@ end end emit_error(B, orig, "Enzyme: Not yet implemented forward for jl_idtable_rehash") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -929,9 +1308,14 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_idtable_rehash") + emit_error( + B, + orig, + "Enzyme: Not yet implemented augmented forward for jl_idtable_rehash", + ) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -955,17 +1339,31 @@ end shadowin = invert_pointer(gutils, origops[1], B) if width == 1 args = LLVM.Value[ - shadowin - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + shadowin + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# else - for idx in 1:width + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# end end return false @@ -997,14 +1395,21 @@ end tot = mul!(B, inc, elsz) args = LLVM.Value[anti, inc] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[off]) mcall = LLVM.memset!(B, toset, LLVM.ConstantInt(i8, 0, false), tot, al) else - for idx in 1:width - anti = extract_value!(B, shadowin, idx-1) + for idx = 1:width + anti = extract_value!(B, shadowin, idx - 1) idx = get_array_nrows(B, anti) elsz = zext!(B, get_array_elsz(B, anti), value_type(idx)) @@ -1012,7 +1417,14 @@ end tot = mul!(B, inc, elsz) args = LLVM.Value[anti, inc] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + 
call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[off]) @@ -1042,17 +1454,18 @@ end if width == 1 args = LLVM.Value[ - shadowin - offset - ] + shadowin + offset + ] LLVM.call!(B, fty, delF, args) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - offset - ] + extract_value!(B, shadowin, idx - 1) + offset + ] LLVM.call!(B, fty, delF, args) end end @@ -1086,32 +1499,40 @@ end # TODO get actual alignment algn = 0 - + i8 = LLVM.IntType(8) - for idx in 1:width + for idx = 1:width anti = if width == 1 shadowin else - extract_value!(B, shadowin, idx-1) + extract_value!(B, shadowin, idx - 1) end if get_runtime_activity(gutils) - emit_error(B, orig, "Enzyme: Not yet implemented runtime activity for reverse of jl_array_del_end") + emit_error( + B, + orig, + "Enzyme: Not yet implemented runtime activity for reverse of jl_array_del_end", + ) end args = LLVM.Value[anti, offset] - - found, arty = abs_typeof(origops[1]) + + found, arty, byref = abs_typeof(origops[1]) anti = shadowin elSize = if found LLVM.ConstantInt(Csize_t(sizeof(eltype(arty)))) else - elSize = LLVM.zext!(B, get_array_elsz(B, anti), LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!( + B, + get_array_elsz(B, anti), + LLVM.IntType(8 * sizeof(Csize_t)), + ) end len = get_array_len(B, anti) - + LLVM.call!(B, fty, delF, args) - + length = LLVM.mul!(B, len, elSize) - + if !found && !(eltype(arty) <: Base.IEEEFloat) GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $((found, arty)) in $(string(origops[1]))" end @@ -1138,22 +1559,30 @@ end push!(args, v) end push!(args, 
new_from_original(gutils, origops[end-1])) - valTys = API.CValueType[API.VT_Shadow, API.VT_Shadow, API.VT_Shadow, API.VT_Shadow, API.VT_Primal] + valTys = API.CValueType[ + API.VT_Shadow, + API.VT_Shadow, + API.VT_Shadow, + API.VT_Shadow, + API.VT_Primal, + ] if width == 1 vargs = args - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=# debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width vargs = LLVM.Value[] for a in args[1:end-1] - push!(vargs, extract_value!(B, a, idx-1)) + push!(vargs, extract_value!(B, a, idx - 1)) end push!(vargs, args[end]) - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, #=lookup=#false) + cal = + call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=# debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) end @@ -1162,7 +1591,7 @@ end return false end @register_aug function jl_array_ptr_copy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) + jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) end @register_rev function jl_array_ptr_copy_rev(B, orig, gutils, tape) return nothing @@ -1178,18 +1607,33 @@ end shadowin = invert_pointer(gutils, origops[1], B) if width == 1 args = LLVM.Value[ - shadowin - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + shadowin + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# else - 
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# end end return false @@ -1206,9 +1650,10 @@ end @register_fwd function jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) newo = new_from_original(gutils, orig) origops = collect(operands(orig)) - err = emit_error(B, orig, "Enzyme: unhandled forward for "*string(origops[end])) + err = emit_error(B, orig, "Enzyme: unhandled forward for " * string(origops[end])) API.moveBefore(newo, err, C_NULL) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing width = get_width(gutils) @@ -1216,9 +1661,11 @@ end shadowres = normal else position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal))) - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1226,7 +1673,7 @@ end return false end @register_aug function jl_unhandled_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) + jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) end @register_rev function jl_unhandled_rev(B, orig, gutils, tape) return nothing @@ -1241,16 +1688,21 @@ end API.moveBefore(newo, err, B) if unsafe_load(shadowR) != C_NULL - valTys = API.CValueType[API.VT_Primal, API.VT_Primal] - args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])] - normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + valTys = API.CValueType[API.VT_Primal, API.VT_Primal] + args = [ + new_from_original(gutils, operands(orig)[1]), + new_from_original(gutils, operands(orig)[2]), + ] + normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=# width = get_width(gutils) if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end 
unsafe_store!(shadowR, shadowres.ref) @@ -1262,20 +1714,29 @@ end if is_constant_value(gutils, orig) return true end - err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_get_binding_or_error") + err = emit_error( + B, + orig, + "Enzyme: unhandled augmented forward for jl_get_binding_or_error", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) if unsafe_load(shadowR) != C_NULL - valTys = API.CValueType[API.VT_Primal, API.VT_Primal] - args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])] - normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + valTys = API.CValueType[API.VT_Primal, API.VT_Primal] + args = [ + new_from_original(gutils, operands(orig)[1]), + new_from_original(gutils, operands(orig)[2]), + ] + normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=# width = get_width(gutils) if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1292,10 +1753,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "Enzyme: unhandled forward for jl_gc_add_finalizer_th or jl_gc_add_ptr_finalizer") + err = emit_error( + B, + orig, + "Enzyme: unhandled forward for jl_gc_add_finalizer_th or jl_gc_add_ptr_finalizer", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1306,10 +1772,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_gc_add_finalizer_th") + err = emit_error( + B, + orig, + "Enzyme: unhandled augmented forward for jl_gc_add_finalizer_th", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1333,10 +1804,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + err = emit_error( + B, + orig, + "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1347,10 +1823,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + err = emit_error( + B, + orig, + "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1369,7 +1850,7 @@ end end -function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=nothing) +function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler = nothing) for variant in variants if augfwd_handler !== nothing && rev_handler !== nothing API.EnzymeRegisterCallHandler(variant, augfwd_handler, rev_handler) @@ -1381,31 +1862,71 @@ function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=no end macro augfunc(f) - cname = Symbol(string(f)*"_cfunc") - :(@cfunction($cname, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + $cname, + UInt8, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + Ptr{LLVM.API.LLVMValueRef}, + Ptr{LLVM.API.LLVMValueRef}, + Ptr{LLVM.API.LLVMValueRef}, + ) )) end macro revfunc(f) - cname = Symbol(string(f)*"_cfunc") - :(@cfunction($cname, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + $cname, + Cvoid, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + LLVM.API.LLVMValueRef, + ) )) end macro fwdfunc(f) - cname = Symbol(string(f)*"_cfunc") - :(@cfunction($cname, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + $cname, + UInt8, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + Ptr{LLVM.API.LLVMValueRef}, + Ptr{LLVM.API.LLVMValueRef}, + ) )) end macro diffusefunc(f) - cname = Symbol(string(f)*"_cfunc") - 
:(@cfunction(Compiler.$cname, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + Compiler.$cname, + UInt8, + ( + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + LLVM.API.LLVMValueRef, + UInt8, + API.CDerivativeMode, + Ptr{UInt8}, + ) )) end @noinline function register_llvm_rules() - API.EnzymeRegisterDiffUseCallHandler("enzyme_custom", @diffusefunc(enzyme_custom_diffuse)) + API.EnzymeRegisterDiffUseCallHandler( + "enzyme_custom", + @diffusefunc(enzyme_custom_diffuse) + ) register_handler!( ("julia.call",), @augfunc(jlcall_augfwd), @@ -1473,79 +1994,79 @@ end @fwdfunc(wait_fwd), ) register_handler!( - ("jl_","jl_breakpoint"), + ("jl_", "jl_breakpoint"), @augfunc(noop_augfwd), @revfunc(duplicate_rev), @fwdfunc(noop_fwd), ) register_handler!( - ("jl_array_copy","ijl_array_copy"), + ("jl_array_copy", "ijl_array_copy"), @augfunc(arraycopy_augfwd), @revfunc(arraycopy_rev), @fwdfunc(arraycopy_fwd), ) register_handler!( - ("jl_reshape_array","ijl_reshape_array"), + ("jl_reshape_array", "ijl_reshape_array"), @augfunc(arrayreshape_augfwd), @revfunc(arrayreshape_rev), @fwdfunc(arrayreshape_fwd), ) register_handler!( - ("jl_f_setfield","ijl_f_setfield"), + ("jl_f_setfield", "ijl_f_setfield"), @augfunc(setfield_augfwd), @revfunc(setfield_rev), @fwdfunc(setfield_fwd), ) register_handler!( - ("jl_box_float32","ijl_box_float32", "jl_box_float64", "ijl_box_float64"), + ("jl_box_float32", "ijl_box_float32", "jl_box_float64", "ijl_box_float64"), @augfunc(boxfloat_augfwd), @revfunc(boxfloat_rev), @fwdfunc(boxfloat_fwd), ) register_handler!( - ("jl_f_tuple","ijl_f_tuple"), + ("jl_f_tuple", "ijl_f_tuple"), @augfunc(f_tuple_augfwd), @revfunc(f_tuple_rev), @fwdfunc(f_tuple_fwd), ) register_handler!( - ("jl_eqtable_get","ijl_eqtable_get"), + ("jl_eqtable_get", "ijl_eqtable_get"), @augfunc(eqtableget_augfwd), @revfunc(eqtableget_rev), 
@fwdfunc(eqtableget_fwd), ) register_handler!( - ("jl_eqtable_put","ijl_eqtable_put"), + ("jl_eqtable_put", "ijl_eqtable_put"), @augfunc(eqtableput_augfwd), @revfunc(eqtableput_rev), @fwdfunc(eqtableput_fwd), ) register_handler!( - ("jl_idtable_rehash","ijl_idtable_rehash"), + ("jl_idtable_rehash", "ijl_idtable_rehash"), @augfunc(idtablerehash_augfwd), @revfunc(idtablerehash_rev), @fwdfunc(idtablerehash_fwd), ) register_handler!( - ("jl_f__apply_iterate","ijl_f__apply_iterate"), + ("jl_f__apply_iterate", "ijl_f__apply_iterate"), @augfunc(apply_iterate_augfwd), @revfunc(apply_iterate_rev), @fwdfunc(apply_iterate_fwd), ) register_handler!( - ("jl_f__svec_ref","ijl_f__svec_ref"), + ("jl_f__svec_ref", "ijl_f__svec_ref"), @augfunc(f_svec_ref_augfwd), @revfunc(f_svec_ref_rev), @fwdfunc(f_svec_ref_fwd), ) register_handler!( - ("jl_new_structv","ijl_new_structv"), + ("jl_new_structv", "ijl_new_structv"), @augfunc(new_structv_augfwd), @revfunc(new_structv_rev), @fwdfunc(new_structv_fwd), ) register_handler!( - ("jl_new_structt","ijl_new_structt"), + ("jl_new_structt", "ijl_new_structt"), @augfunc(new_structt_augfwd), @revfunc(new_structt_rev), @fwdfunc(new_structt_fwd), @@ -1557,7 +2078,12 @@ end @fwdfunc(get_binding_or_error_fwd), ) register_handler!( - ("jl_gc_add_finalizer_th","ijl_gc_add_finalizer_th", "jl_gc_add_ptr_finalizer","ijl_gc_add_ptr_finalizer"), + ( + "jl_gc_add_finalizer_th", + "ijl_gc_add_finalizer_th", + "jl_gc_add_ptr_finalizer", + "ijl_gc_add_ptr_finalizer", + ), @augfunc(finalizer_augfwd), @revfunc(finalizer_rev), @fwdfunc(finalizer_fwd), @@ -1569,37 +2095,37 @@ end @fwdfunc(deferred_fwd), ) register_handler!( - ("jl_array_grow_end","ijl_array_grow_end"), + ("jl_array_grow_end", "ijl_array_grow_end"), @augfunc(jl_array_grow_end_augfwd), @revfunc(jl_array_grow_end_rev), @fwdfunc(jl_array_grow_end_fwd), ) register_handler!( - ("jl_array_del_end","ijl_array_del_end"), + ("jl_array_del_end", "ijl_array_del_end"), @augfunc(jl_array_del_end_augfwd), 
@revfunc(jl_array_del_end_rev), @fwdfunc(jl_array_del_end_fwd), ) register_handler!( - ("jl_f_getfield","ijl_f_getfield"), + ("jl_f_getfield", "ijl_f_getfield"), @augfunc(jl_getfield_augfwd), @revfunc(jl_getfield_rev), @fwdfunc(jl_getfield_fwd), ) register_handler!( - ("ijl_get_nth_field_checked","jl_get_nth_field_checked"), + ("ijl_get_nth_field_checked", "jl_get_nth_field_checked"), @augfunc(jl_nthfield_augfwd), @revfunc(jl_nthfield_rev), @fwdfunc(jl_nthfield_fwd), ) register_handler!( - ("jl_array_sizehint","ijl_array_sizehint"), + ("jl_array_sizehint", "ijl_array_sizehint"), @augfunc(jl_array_sizehint_augfwd), @revfunc(jl_array_sizehint_rev), @fwdfunc(jl_array_sizehint_fwd), ) register_handler!( - ("jl_array_ptr_copy","ijl_array_ptr_copy"), + ("jl_array_ptr_copy", "ijl_array_ptr_copy"), @augfunc(jl_array_ptr_copy_augfwd), @revfunc(jl_array_ptr_copy_rev), @fwdfunc(jl_array_ptr_copy_fwd), diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 2964838947..d882fc2672 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -1,9 +1,30 @@ -function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}) where {FT1, FT2, World, width, RuntimeActivity} +function runtime_newtask_fwd( + world::Val{World}, + fn::FT1, + dfn::FT2, + post::Any, + ssize::Int, + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, +) where {FT1,FT2,World,width,RuntimeActivity} FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward = thunk( + opt_mi, + (ghos ? 
Const : Duplicated){FT}, + Const, + Tuple{}, + Val(API.DEM_ForwardMode), + Val(width), + Val((false,)), + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# ft = ghos ? Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -13,12 +34,34 @@ function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ss return ccall(:jl_new_task, Ref{Task}, (Any, Any, Int), fclosure, post, ssize) end -function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween, RuntimeActivity} +function runtime_newtask_augfwd( + world::Val{World}, + fn::FT1, + dfn::FT2, + post::Any, + ssize::Int, + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{ModifiedBetween}, +) where {FT1,FT2,World,width,ModifiedBetween,RuntimeActivity} # TODO make this AD subcall type stable FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward, adjoint = thunk( + opt_mi, + (ghos ? Const : Duplicated){FT}, + Const, + Tuple{}, + Val(API.DEM_ReverseModePrimal), + Val(width), + Val(ModifiedBetween), + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -41,13 +84,17 @@ function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, end -function referenceCaller(fn::Ref{Clos}, args...) where Clos +function referenceCaller(fn::Ref{Clos}, args...) where {Clos} fval = fn[] fval = fval::Clos fval(args...) 
end -function runtime_pfor_fwd(thunk::ThunkTy, ft::FT, threading_args...)::Cvoid where {ThunkTy, FT} +function runtime_pfor_fwd( + thunk::ThunkTy, + ft::FT, + threading_args..., +)::Cvoid where {ThunkTy,FT} function fwd(tid_args...) if length(tid_args) == 0 thunk(ft) @@ -59,12 +106,21 @@ function runtime_pfor_fwd(thunk::ThunkTy, ft::FT, threading_args...)::Cvoid wher return end -function runtime_pfor_augfwd(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, threading_args...) where {ThunkTy, FT, AnyJL, byRef} +function runtime_pfor_augfwd( + thunk::ThunkTy, + ft::FT, + ::Val{AnyJL}, + ::Val{byRef}, + threading_args..., +) where {ThunkTy,FT,AnyJL,byRef} TapeType = EnzymeRules.tape_type(ThunkTy) tapes = if AnyJL Vector{TapeType}(undef, Base.Threads.nthreads()) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*Base.Threads.nthreads())) + Base.unsafe_convert( + Ptr{TapeType}, + Libc.malloc(sizeof(TapeType) * Base.Threads.nthreads()), + ) end function fwd(tid_args...) @@ -94,7 +150,14 @@ function runtime_pfor_augfwd(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, return tapes end -function runtime_pfor_rev(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, tapes, threading_args...) where {ThunkTy, FT, AnyJL, byRef} +function runtime_pfor_rev( + thunk::ThunkTy, + ft::FT, + ::Val{AnyJL}, + ::Val{byRef}, + tapes, + threading_args..., +) where {ThunkTy,FT,AnyJL,byRef} function rev(tid_args...) 
tid = if length(tid_args) == 0 tid = Base.Threads.threadid() @@ -130,7 +193,7 @@ function runtime_pfor_rev(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, ta return nothing end -@inline function threadsfor_common(orig, gutils, B, mode, tape=nothing) +@inline function threadsfor_common(orig, gutils, B, mode, tape = nothing) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -182,25 +245,54 @@ end width = get_width(gutils) ops = collect(operands(orig))[1:end-1] - dupClosure = !guaranteed_const_nongen(funcT, world) && !is_constant_value(gutils, ops[1]) + dupClosure = + !guaranteed_const_nongen(funcT, world) && !is_constant_value(gutils, ops[1]) pdupClosure = dupClosure subfunc = nothing if mode == API.DEM_ForwardMode if fwdmodenm === nothing etarget = Compiler.EnzymeTarget() - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) - ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) + eparams = Compiler.EnzymeCompilerParams( + Tuple{(dupClosure ? 
Duplicated : Const){funcT},e_tt.parameters...}, + API.DEM_ForwardMode, + width, + Const{Nothing}, + true, + true, + modifiedBetween, + false, + false, + UnknownTapeType, + FFIABI, + false, + get_runtime_activity(gutils), + ) #=ErrIfFuncWritten=# + ejob = Compiler.CompilerJob( + mi2, + CompilerConfig(etarget, eparams; kernel = false), + world, + ) + + cmod, fwdmodenm, _, _ = _thunk(ejob, false) #=postopt=# - cmod, fwdmodenm, _, _ = _thunk(ejob, #=postopt=#false) - LLVM.link!(mod, cmod) push!(attributes, StringAttribute("enzymejl_forward", fwdmodenm)) - push!(function_attributes(functions(mod)[fwdmodenm]), EnumAttribute("alwaysinline")) + push!( + function_attributes(functions(mod)[fwdmodenm]), + EnumAttribute("alwaysinline"), + ) permit_inlining!(functions(mod)[fwdmodenm]) end - thunkTy = ForwardModeThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, #=returnPrimal=#false} + thunkTy = ForwardModeThunk{ + Ptr{Cvoid}, + dupClosure ? Duplicated{funcT} : Const{funcT}, + Const{Nothing}, + e_tt, + width, + false, + } #=returnPrimal=# subfunc = functions(mod)[fwdmodenm] elseif mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient @@ -209,7 +301,7 @@ end has_active = ty == MixedState || ty == ActiveState if has_active refed = true - e_tt = Tuple{Duplicated{Base.RefValue{funcT}}, e_tt.parameters...} + e_tt = Tuple{Duplicated{Base.RefValue{funcT}},e_tt.parameters...} funcT = Core.Typeof(referenceCaller) dupClosure = false modifiedBetween = (false, modifiedBetween...) @@ -220,30 +312,75 @@ end if augfwdnm === nothing || adjointnm === nothing etarget = Compiler.EnzymeTarget() # TODO modifiedBetween - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) - ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) - - cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, #=postopt=#false) + eparams = Compiler.EnzymeCompilerParams( + Tuple{(dupClosure ? Duplicated : Const){funcT},e_tt.parameters...}, + API.DEM_ReverseModePrimal, + width, + Const{Nothing}, + true, + true, + modifiedBetween, + false, + false, + UnknownTapeType, + FFIABI, + false, + get_runtime_activity(gutils), + ) #=ErrIfFuncWritten=# + ejob = Compiler.CompilerJob( + mi2, + CompilerConfig(etarget, eparams; kernel = false), + world, + ) + + cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) push!(attributes, StringAttribute("enzymejl_augforward", augfwdnm)) - push!(function_attributes(functions(mod)[augfwdnm]), EnumAttribute("alwaysinline")) + push!( + function_attributes(functions(mod)[augfwdnm]), + EnumAttribute("alwaysinline"), + ) permit_inlining!(functions(mod)[augfwdnm]) push!(attributes, StringAttribute("enzymejl_adjoint", adjointnm)) - push!(function_attributes(functions(mod)[adjointnm]), EnumAttribute("alwaysinline")) + push!( + function_attributes(functions(mod)[adjointnm]), + EnumAttribute("alwaysinline"), + ) permit_inlining!(functions(mod)[adjointnm]) - push!(attributes, StringAttribute("enzymejl_tapetype", string(convert(UInt, unsafe_to_pointer(TapeType))))) - + push!( + attributes, + StringAttribute( + "enzymejl_tapetype", + string(convert(UInt, unsafe_to_pointer(TapeType))), + ), + ) + end if mode == API.DEM_ReverseModePrimal - thunkTy = AugmentedForwardThunk{Ptr{Cvoid}, dupClosure ? 
Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, #=returnPrimal=#true, TapeType} + thunkTy = AugmentedForwardThunk{ + Ptr{Cvoid}, + dupClosure ? Duplicated{funcT} : Const{funcT}, + Const{Nothing}, + e_tt, + width, + true, + TapeType, + } #=returnPrimal=# subfunc = functions(mod)[augfwdnm] - else - thunkTy = AdjointThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, TapeType} + else + thunkTy = AdjointThunk{ + Ptr{Cvoid}, + dupClosure ? Duplicated{funcT} : Const{funcT}, + Const{Nothing}, + e_tt, + width, + TapeType, + } subfunc = functions(mod)[adjointnm] end else @@ -251,7 +388,7 @@ end end ppfuncT = pfuncT - dpfuncT = width == 1 ? pfuncT : NTuple{(Int)width, pfuncT} + dpfuncT = width == 1 ? pfuncT : NTuple{(Int)width,pfuncT} if refed dpfuncT = Base.RefValue{dpfuncT} @@ -263,7 +400,7 @@ end if width == 1 dfuncT = Duplicated{dfuncT} else - dfuncT = BatchDuplicated{dfuncT, Int(width)} + dfuncT = BatchDuplicated{dfuncT,Int(width)} end else dfuncT = Const{dfuncT} @@ -273,7 +410,7 @@ end alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) - ll_th = convert(LLVMType, thunkTy) + ll_th = convert(LLVMType, thunkTy) al = alloca!(alloctx, ll_th) al = addrspacecast!(B, al, LLVM.PointerType(ll_th, Tracked)) al = addrspacecast!(B, al, LLVM.PointerType(ll_th, Derived)) @@ -320,7 +457,12 @@ end val0 = v end - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)], + ) store!(B, val0, ptr) if pdupClosure @@ -343,7 +485,8 @@ end if refed dval0 = dval = emit_allocobj!(B, dpfuncT) - dval = bitcast!(B, dval, LLVM.PointerType(spllty, addrspace(value_type(dval)))) + dval = + bitcast!(B, dval, LLVM.PointerType(spllty, addrspace(value_type(dval)))) dval = addrspacecast!(B, dval, 
LLVM.PointerType(spllty, Derived)) store!(B, dv, dval) if pv !== nothing @@ -356,7 +499,15 @@ end dval0 = dv end - dptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) + dptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 1), + ], + ) store!(B, dval0, dptr) end @@ -379,12 +530,15 @@ end end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing - _, sname, dfuncT, vals, thunkTy, _, _ = threadsfor_common(orig, gutils, B, API.DEM_ForwardMode) + _, sname, dfuncT, vals, thunkTy, _, _ = + threadsfor_common(orig, gutils, B, API.DEM_ForwardMode) - tt = Tuple{thunkTy, dfuncT, Bool} + tt = Tuple{thunkTy,dfuncT,Bool} mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world) @@ -414,12 +568,21 @@ end return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(shadowR)) : nothing - byRef, sname, dfuncT, vals, thunkTy, _, copies = threadsfor_common(orig, gutils, B, API.DEM_ReverseModePrimal) + byRef, sname, dfuncT, vals, thunkTy, _, copies = + threadsfor_common(orig, gutils, B, API.DEM_ReverseModePrimal) - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, Bool} + tt = Tuple{ + thunkTy, + dfuncT, + Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, + Val{byRef}, + Bool, + } mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world) @@ -459,7 +622,8 @@ end return end - byRef, sname, dfuncT, vals, thunkTy, TapeType, copies = threadsfor_common(orig, gutils, B, API.DEM_ReverseModeGradient, tape) + byRef, sname, dfuncT, vals, thunkTy, TapeType, copies = + threadsfor_common(orig, gutils, B, API.DEM_ReverseModeGradient, tape) STT = if !any_jltypes(TapeType) Ptr{TapeType} @@ -467,7 +631,14 @@ end Vector{TapeType} end - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, STT, Bool} + tt = Tuple{ + thunkTy, + dfuncT, + Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, + Val{byRef}, + STT, + Bool, + } mode = get_mode(gutils) entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world) push!(function_attributes(entry), EnumAttribute("alwaysinline")) @@ -499,15 +670,18 @@ end ops = collect(operands(orig)) vals = LLVM.Value[ - unsafe_to_llvm(B, runtime_newtask_fwd), - unsafe_to_llvm(B, Val(world)), - new_from_original(gutils, ops[1]), - invert_pointer(gutils, ops[1], B), - new_from_original(gutils, ops[2]), - (sizeof(Int) == sizeof(Int64) ? emit_box_int64! 
: emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), - unsafe_to_llvm(B, Val(width)), - ] + unsafe_to_llvm(B, runtime_newtask_fwd), + unsafe_to_llvm(B, Val(world)), + new_from_original(gutils, ops[1]), + invert_pointer(gutils, ops[1], B), + new_from_original(gutils, ops[2]), + (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)( + B, + new_from_original(gutils, ops[3]), + ), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), + unsafe_to_llvm(B, Val(width)), + ] ntask = emit_apply_generic!(B, vals) debug_from_orig!(gutils, ntask, orig) @@ -532,9 +706,11 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing - + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -550,15 +726,19 @@ end ops = collect(operands(orig)) vals = LLVM.Value[ - unsafe_to_llvm(B, runtime_newtask_augfwd), - unsafe_to_llvm(B, Val(world)), - new_from_original(gutils, ops[1]), - invert_pointer(gutils, ops[1], B), - new_from_original(gutils, ops[2]), - (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), unsafe_to_llvm(B, Val(width)), - unsafe_to_llvm(B, Val(ModifiedBetween)), - ] + unsafe_to_llvm(B, runtime_newtask_augfwd), + unsafe_to_llvm(B, Val(world)), + new_from_original(gutils, ops[1]), + invert_pointer(gutils, ops[1], B), + new_from_original(gutils, ops[2]), + (sizeof(Int) == sizeof(Int64) ? emit_box_int64! 
: emit_box_int32!)( + B, + new_from_original(gutils, ops[3]), + ), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), + unsafe_to_llvm(B, Val(width)), + unsafe_to_llvm(B, Val(ModifiedBetween)), + ] ntask = emit_apply_generic!(B, vals) debug_from_orig!(gutils, ntask, orig) @@ -569,12 +749,20 @@ end sret = LLVM.pointercast!(B, sret, LLVM.PointerType(AT, Derived)) if shadowR != C_NULL - shadow = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)])) + shadow = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]), + ) unsafe_store!(shadowR, shadow.ref) end if normalR != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) end @@ -596,16 +784,18 @@ end if width == 1 nops = LLVM.Value[inv, new_from_original(gutils, ops[2])] valTys = API.CValueType[API.VT_Shadow, API.VT_Primal] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, false) #=lookup=# debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) else - for idx in 1:width - nops = LLVM.Value[extract_value(B, inv, idx-1), - new_from_original(gutils, ops[2])] + for idx = 1:width + nops = LLVM.Value[ + extract_value(B, inv, idx - 1), + new_from_original(gutils, ops[2]), + ] valTys = API.CValueType[API.VT_Shadow, API.VT_Primal] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, #=lookup=#false) - + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, false) #=lookup=# + debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) end @@ -626,7 +816,8 @@ end if is_constant_value(gutils, orig) && 
is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -663,7 +854,11 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) waitfn = find_match(mod, "jl_wait") if waitfn === nothing - emit_error(B, orig, "Enzyme: could not find jl_wait fn to create shadow of jl_enq_work") + emit_error( + B, + orig, + "Enzyme: could not find jl_wait fn to create shadow of jl_enq_work", + ) return nothing end @assert waitfn !== nothing @@ -678,7 +873,8 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -689,7 +885,8 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -702,7 +899,11 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) enq_work_fn = find_match(mod, "jl_enq_work") if enq_work_fn === nothing - emit_error(B, orig, "Enzyme: could not find jl_enq_work fn to create shadow of wait") + emit_error( + B, + orig, + "Enzyme: could not find jl_enq_work fn to create shadow of wait", + ) return nothing end @assert enq_work_fn !== nothing diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 569ef87323..2a2d1032c1 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -1,12 +1,26 @@ -function int_return_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 +function int_return_rule( + direction::Cint, + ret::API.CTypeTreeRef, + args::Ptr{API.CTypeTreeRef}, + known_values::Ptr{API.IntList}, + numArgs::Csize_t, + val::LLVM.API.LLVMValueRef, +)::UInt8 TT = TypeTree(API.DT_Integer, LLVM.context(LLVM.Value(val))) only!(TT, -1) API.EnzymeMergeTypeTree(ret, TT) return UInt8(false) end -function inout_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 +function inout_rule( + direction::Cint, + ret::API.CTypeTreeRef, + args::Ptr{API.CTypeTreeRef}, + known_values::Ptr{API.IntList}, + numArgs::Csize_t, + val::LLVM.API.LLVMValueRef, +)::UInt8 if numArgs != 1 return UInt8(false) end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 6117e464d8..500ff53d9a 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1,4 +1,12 @@ -function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, tuple) +function body_construct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + 
batchshadowargs, + tuple, +) shadow_rets = Vector{Expr}[] results = quote $(active_refs...) @@ -6,29 +14,35 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch @assert length(primtypes) == N @assert length(primargs) == N @assert length(batchshadowargs) == N - for i in 1:N + for i = 1:N @assert length(batchshadowargs[i]) == Width shadow_rets_i = Expr[] aref = Symbol("active_ref_$i") - for w in 1:Width - sref = Symbol("sub_shadow_"*string(i)*"_"*string(w)) - push!(shadow_rets_i, quote - $sref = if $aref == AnyState - $(primargs[i]); - else - if !ActivityTup[$i] - if ($aref == DupState || $aref == MixedState) && $(batchshadowargs[i][w]) === nothing - prim = $(primargs[i]) - throw("Error cannot store inactive but differentiable variable $prim into active tuple") - end - end - if $aref == DupState - $(batchshadowargs[i][w]) + for w = 1:Width + sref = Symbol("sub_shadow_" * string(i) * "_" * string(w)) + push!( + shadow_rets_i, + quote + $sref = if $aref == AnyState + $(primargs[i]) else - $(batchshadowargs[i][w])[] + if !ActivityTup[$i] + if ($aref == DupState || $aref == MixedState) && + $(batchshadowargs[i][w]) === nothing + prim = $(primargs[i]) + throw( + "Error cannot store inactive but differentiable variable $prim into active tuple", + ) + end + end + if $aref == DupState + $(batchshadowargs[i][w]) + else + $(batchshadowargs[i][w])[] + end end - end - end) + end, + ) end push!(shadow_rets, shadow_rets_i) end @@ -36,11 +50,11 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch refs = Expr[] ref_syms = Symbol[] res_syms = Symbol[] - for w in 1:Width + for w = 1:Width sres = Symbol("result_$w") ref_res = Symbol("ref_result_$w") combined = Expr[] - for i in 1:N + for i = 1:N push!(combined, shadow_rets[i][w]) end if tuple @@ -85,10 +99,18 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch end -function body_construct_rev(N, Width, primtypes, active_refs, primargs, 
batchshadowargs, tuple) +function body_construct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + tuple, +) outs = [] - for i in 1:N - for w in 1:Width + for i = 1:N + for w = 1:Width tsym = Symbol("tval_$w") expr = if tuple :($tsym[$i]) @@ -96,20 +118,25 @@ function body_construct_rev(N, Width, primtypes, active_refs, primargs, batchsha :(getfield($tsym, $i)) end shad = batchshadowargs[i][w] - out = :(if $(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState - if $shad isa Base.RefValue - $shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive) - else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + out = :( + if $(Symbol("active_ref_$i")) == MixedState || + $(Symbol("active_ref_$i")) == ActiveState + if $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive) + else + error( + "Enzyme Mutability Error: Cannot add one in place to immutable value " * + string($shad), + ) + end end - end ) push!(outs, out) end end - tapes = Expr[:(tval_1 = tape[])] - for w in 2:Width + tapes = Expr[:(tval_1 = tape[])] + for w = 2:Width sym = Symbol("tval_$w") df = Symbol("df_$w") push!(tapes, :($sym = $df[])) @@ -131,87 +158,226 @@ function body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batc body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) end -function body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +function body_runtime_newstruct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, +) body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) end -function body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +function body_runtime_tuple_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, +) 
body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) end function func_runtime_tuple_augfwd(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; func=false, mixed_or_active=true) - body = body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; func = false, mixed_or_active = true) + body = body_runtime_tuple_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) quote - function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, $(typeargs...)} + function runtime_tuple_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + $(allargs...), + )::ReturnType where {ActivityTup,MB,ReturnType,$(typeargs...)} $body end end end -@generated function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType} +@generated function runtime_tuple_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + allargs..., +)::ReturnType where {ActivityTup,MB,Width,ReturnType} N = div(length(allargs), Width) - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; func=false, mixed_or_active=true) - return body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; func = false, mixed_or_active = true) + 
return body_runtime_tuple_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) end function func_runtime_tuple_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) - body = body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; mixed_or_active = true) + body = + body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) quote - function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, TapeType, $(typeargs...)} + function runtime_tuple_rev( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + $(allargs...), + ) where {ActivityTup,MB,TapeType,$(typeargs...)} $body end end end -@generated function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, allargs...) 
where {ActivityTup, MB, Width, TapeType} - N = div(length(allargs)-(Width-1), Width) - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) - return body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) -end - - -function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) - body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) +@generated function runtime_tuple_rev( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + allargs..., +) where {ActivityTup,MB,Width,TapeType} + N = div(length(allargs) - (Width - 1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; mixed_or_active = true) + return body_runtime_tuple_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) +end + + +function body_runtime_newstruct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, +) + body_construct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + false, + ) end function func_runtime_newstruct_augfwd(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) - body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; mixed_or_active = true) + body = body_runtime_newstruct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) quote - function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, 
$(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} + function runtime_newstruct_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + ::Type{NewType}, + RT::Val{ReturnType}, + $(allargs...), + )::ReturnType where {ActivityTup,MB,ReturnType,NewType,$(typeargs...)} $body end end end -@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType} - N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) - return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +@generated function runtime_newstruct_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + ::Type{NewType}, + RT::Val{ReturnType}, + allargs..., +)::ReturnType where {ActivityTup,MB,Width,ReturnType,NewType} + N = div(length(allargs) + 2, Width + 1) - 1 + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; mixed_or_active = true) + return body_runtime_newstruct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) end function func_runtime_newstruct_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) - body = body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; mixed_or_active = true) + body = body_runtime_newstruct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) 
quote - function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewStruct}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, NewStruct, TapeType, $(typeargs...)} + function runtime_newstruct_rev( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + ::Type{NewStruct}, + tape::TapeType, + $(allargs...), + ) where {ActivityTup,MB,NewStruct,TapeType,$(typeargs...)} $body end end end -@generated function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewStruct}, tape::TapeType, allargs...) where {ActivityTup, MB, Width, NewStruct, TapeType} - N = div(length(allargs)-(Width-1), Width) - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) - return body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +@generated function runtime_newstruct_rev( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + ::Type{NewStruct}, + tape::TapeType, + allargs..., +) where {ActivityTup,MB,Width,NewStruct,TapeType} + N = div(length(allargs) - (Width - 1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; mixed_or_active = true) + return body_runtime_newstruct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) end for (N, Width) in Iterators.product(0:30, 1:10) @@ -235,7 +401,8 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) abs = [abs_typeof(v) for v in origops[offset+1:end-1]] @assert length(icvs) == length(abs) - for (icv, (found_partial, typ_partial), (found, typ)) in zip(icvs, abs_partial, abs) + for (icv, (found_partial, typ_partial, byref_partial), (found, typ, byref)) in + zip(icvs, abs_partial, abs) # Constants not handled 
unless known inactive from type if icv if !found_partial @@ -251,7 +418,8 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) if !found_partial return false end - act = active_reg_inner(typ_partial, (), world, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) + act = + active_reg_inner(typ_partial, (), world, Val(false), Val(false), Val(true)) #=abstractismixed=# if act == MixedState || act == ActiveState return false end @@ -262,7 +430,7 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) return true end - shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:end-1] ] + shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:end-1]] if width == 1 if offset != 1 pushfirst!(shadowsin, origops[1]) @@ -270,17 +438,16 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), shadowsin) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[ - extract_value!(B, s, idx-1) for s in shadowsin - ] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[extract_value!(B, s, idx - 1) for s in shadowsin] if offset != 1 pushfirst!(args, origops[1]) end tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -291,17 +458,35 @@ end function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, 
get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - if !newstruct_common(#=fwd=#true, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + if !newstruct_common(true, true, offset, B, orig, gutils, normalR, shadowR) #=run=# origops = collect(operands(orig)) abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs_partial)*" "*string([v for v in origops[offset+1:end-1]])) + emit_error( + B, + orig, + "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants=" * + string(icvs) * + " " * + string(orig) * + " " * + string(abs_partial) * + " " * + string([v for v in origops[offset+1:end-1]]), + ) end return false @@ -310,15 +495,26 @@ end function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + if !newstruct_common(false, true, offset, B, orig, gutils, normalR, shadowR) #=run=# + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : + nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : + nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -326,8 +522,20 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap width = get_width(gutils) - sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true, runtime_activity=false) - + sret = generic_setup( + orig, + runtime_newstruct_augfwd, + width == 1 ? Any : AnyArray(Int(width)), + gutils, + offset, + B, + false; + firstconst = true, + endcast = false, + firstconst_after_tape = true, + runtime_activity = false, + ) #=start=# + if width == 1 shadow = sret else @@ -338,10 +546,15 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + cal, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) @@ -359,18 +572,36 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) end needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = 
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 - if !needsShadow - return - end + if !needsShadow + return + end - if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + if !newstruct_common(false, false, offset, B, orig, gutils, nothing, nothing) #=shadowR=# @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true, runtime_activity=false) + generic_setup( + orig, + runtime_newstruct_rev, + Nothing, + gutils, + offset, + B, + true; + firstconst = true, + tape, + firstconst_after_tape = true, + runtime_activity = false, + ) #=start=# end return nothing @@ -383,15 +614,25 @@ end function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if is_constant_value(gutils, orig) || needsShadowP[] == 0 + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true end - if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + if !newstruct_common(false, true, offset, B, orig, gutils, normalR, shadowR) #=run=# + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : + nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : + nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -399,8 +640,18 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) width = get_width(gutils) - sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false, runtime_activity=false) - + sret = generic_setup( + orig, + runtime_tuple_augfwd, + width == 1 ? Any : AnyArray(Int(width)), + gutils, + offset + 1, + B, + false; + endcast = false, + runtime_activity = false, + ) #=start=# + if width == 1 shadow = sret else @@ -411,10 +662,15 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + cal, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) @@ -428,7 +684,13 @@ end function common_f_tuple_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 @@ -440,7 +702,7 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) return true end 
- if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + if !newstruct_common(false, false, offset, B, orig, gutils, nothing, nothing) #=shadowR=# @assert tape !== C_NULL width = get_width(gutils) tape2 = if width != 1 @@ -456,8 +718,13 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + cal, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) push!(res, ld) end @@ -465,7 +732,17 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) else tape end - generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2, runtime_activity=false) + generic_setup( + orig, + runtime_tuple_rev, + Nothing, + gutils, + offset + 1, + B, + true; + tape = tape2, + runtime_activity = false, + ) #=start=# end return nothing end @@ -506,7 +783,12 @@ end @assert is_constant_value(gutils, origops[1]) if is_constant_value(gutils, origops[2]) - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t"*string(orig)) + emit_error( + B, + orig, + "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t" * + string(orig), + ) end shadowsin = invert_pointer(gutils, origops[2], B) @@ -515,12 +797,16 @@ end shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - vals = [new_from_original(gutils, origops[1]), extract_value!(B, shadowsin, idx-1)] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 
1:width + vals = [ + new_from_original(gutils, origops[1]), + extract_value!(B, shadowsin, idx - 1), + ] tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -537,14 +823,24 @@ end end needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 - if !needsShadow - return - end - emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_structt "*string(orig)) + if !needsShadow + return + end + emit_error( + B, + orig, + "Enzyme: Not yet implemented reverse for jl_new_structt " * string(orig), + ) return nothing end @@ -568,9 +864,13 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[new_from_original(gutils, origops[1]), extract_value!(B, shadowin, idx-1)] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[ + new_from_original(gutils, origops[1]), + extract_value!(B, shadowin, idx - 1), + ] for a in origops[3:end-1] push!(args, new_from_original(gutils, a)) end @@ -579,7 +879,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) end tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = 
insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -588,9 +888,11 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -598,7 +900,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -@generated function ntuple_ref_zero(::Val{N}, ::Type{RT}, res) where {N, RT} +@generated function ntuple_ref_zero(::Val{N}, ::Type{RT}, res) where {N,RT} expr = Vector{Expr}(undef, N) fill!(expr, :(Ref{$RT}(make_zero(res)))) return quote @@ -607,9 +909,9 @@ end end end -@generated function ntuple_ref_lookup(::Val{N}, ::Type{RT}, dptrs, symname) where {N, RT} +@generated function ntuple_ref_lookup(::Val{N}, ::Type{RT}, dptrs, symname) where {N,RT} expr = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds expr[i] = quote begin dv = dptrs[$i] @@ -625,7 +927,7 @@ end @generated function ntuple_lookup(::Val{N}, ptrs, symname) where {N} expr = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds expr[i] = quote begin dv = ptrs[$i] @@ -639,11 +941,17 @@ end end end -function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} +function rt_jl_getfield_aug( + ::Val{NT}, + dptr::T, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {NT,T,T2,Nargs,symname,isconst} res = if dptr isa Base.RefValue - Base.getfield(dptr[], symname) + Base.getfield(dptr[], symname) else - 
Base.getfield(dptr, symname) + Base.getfield(dptr, symname) end RT = Core.Typeof(res) @@ -652,13 +960,16 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else - return NT(ntuple_ref_zero(Val(1+length(dptrs)), RT, res)) + return NT(ntuple_ref_zero(Val(1 + length(dptrs)), RT, res)) end elseif actreg == MixedState if length(dptrs) == 0 return Ref{RT}(res) else - fval = NT((Ref{RT}(res), ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname)...)) + fval = NT(( + Ref{RT}(res), + ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname)..., + )) return fval end elseif isconst @@ -678,11 +989,17 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco end end -function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} +function idx_jl_getfield_aug( + ::Val{NT}, + dptr::T, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {NT,T,T2,Nargs,symname,isconst} res = if dptr isa Base.RefValue - Base.getfield(dptr[], symname+1) + Base.getfield(dptr[], symname + 1) else - Base.getfield(dptr, symname+1) + Base.getfield(dptr, symname + 1) end RT = Core.Typeof(res) actreg = active_reg_nothrow(RT, Val(nothing)) @@ -690,13 +1007,16 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc if length(dptrs) == 0 return Ref{RT}(make_zero(res))::Any else - return NT(ntuple_ref_zero(Val(1+length(dptrs)), RT, res)) + return NT(ntuple_ref_zero(Val(1 + length(dptrs)), RT, res)) end elseif actreg == MixedState if length(dptrs) == 0 return Ref{RT}(res) else - fval = NT((Ref{RT}(res), ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname+1)...)) + fval = NT(( + Ref{RT}(res), + ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname + 1)..., + )) return fval end elseif isconst @@ -710,16 +1030,21 @@ function 
idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc if length(dptrs) == 0 return res::Any else - fval = NT((res, ntuple_lookup(Val(length(dptrs)), dptrs, symname+1)...)) + fval = NT((res, ntuple_lookup(Val(length(dptrs)), dptrs, symname + 1)...)) return fval end end end -@generated function recursive_field_add(::Type{dRT}, vload, ::Val{symname}, dret) where {dRT, symname} +@generated function recursive_field_add( + ::Type{dRT}, + vload, + ::Val{symname}, + dret, +) where {dRT,symname} N = fieldcount(dRT) exprs = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds exprs[i] = if fieldname(dRT, i) == symname :(recursive_add(getfield(vload, $i), dret, identity, guaranteed_nonactive)) else @@ -733,11 +1058,17 @@ end end end -function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} +function rt_jl_getfield_rev( + dptr::T, + dret, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {T,T2,Nargs,symname,isconst} cur = if dptr isa Base.RefValue - getfield(dptr[], symname) + getfield(dptr[], symname) else - getfield(dptr, symname) + getfield(dptr, symname) end RT = Core.Typeof(cur) @@ -750,7 +1081,11 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dRT = Core.Typeof(vload) dptr[] = recursive_field_add(dRT, vload, Val(symname), dret[]) else - setfield!(dptr, symname, recursive_add(cur, dret[], identity, guaranteed_nonactive)) + setfield!( + dptr, + symname, + recursive_add(cur, dret[], identity, guaranteed_nonactive), + ) end else if dptr isa Base.RefValue @@ -760,18 +1095,22 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, else setfield!(dptr, symname, recursive_add(cur, dret[1][])) end - for i in 1:length(dptrs) + for i = 1:length(dptrs) if dptrs[i] isa Base.RefValue vload = dptrs[i][] dRT = Core.Typeof(vload) dptrs[i][] = recursive_field_add(dRT, vload, 
Val(symname), dret[1+i][]) else curi = if dptr isa Base.RefValue - Base.getfield(dptrs[i][], symname) + Base.getfield(dptrs[i][], symname) else - Base.getfield(dptrs[i], symname) + Base.getfield(dptrs[i], symname) end - setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) + setfield!( + dptrs[i], + symname, + recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive), + ) end end end @@ -779,10 +1118,15 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, return nothing end -@generated function recursive_index_add(::Type{dRT}, vload, ::Val{symname}, dret) where {dRT, symname} +@generated function recursive_index_add( + ::Type{dRT}, + vload, + ::Val{symname}, + dret, +) where {dRT,symname} N = fieldcount(dRT) exprs = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds exprs[i] = if i == symname :(recursive_add(getfield(vload, $i), dret, identity, guaranteed_nonactive)) else @@ -796,11 +1140,17 @@ end end end -function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} +function idx_jl_getfield_rev( + dptr::T, + dret, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {T,T2,Nargs,symname,isconst} cur = if dptr isa Base.RefValue - Base.getfield(dptr[], symname+1) + Base.getfield(dptr[], symname + 1) else - Base.getfield(dptr, symname+1) + Base.getfield(dptr, symname + 1) end RT = Core.Typeof(cur) @@ -811,30 +1161,43 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = recursive_index_add(dRT, vload, Val(symname+1), dret[]) + dptr[] = recursive_index_add(dRT, vload, Val(symname + 1), dret[]) else - setfield!(dptr, symname+1, recursive_add(cur, dret[], identity, guaranteed_nonactive)) + setfield!( + dptr, + symname + 1, + recursive_add(cur, dret[], identity, 
guaranteed_nonactive), + ) end else if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = recursive_index_add(dRT, vload, Val(symname+1), dret[1][]) + dptr[] = recursive_index_add(dRT, vload, Val(symname + 1), dret[1][]) else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][], identity, guaranteed_nonactive)) + setfield!( + dptr, + symname + 1, + recursive_add(cur, dret[1][], identity, guaranteed_nonactive), + ) end - for i in 1:length(dptrs) + for i = 1:length(dptrs) if dptrs[i] isa Base.RefValue vload = dptrs[i][] dRT = Core.Typeof(vload) - dptrs[i][] = recursive_index_add(dRT, vload, Val(symname+1), dret[1+i][]) + dptrs[i][] = + recursive_index_add(dRT, vload, Val(symname + 1), dret[1+i][]) else curi = if dptr isa Base.RefValue - Base.getfield(dptrs[i][], symname+1) + Base.getfield(dptrs[i][], symname + 1) else - Base.getfield(dptrs[i], symname+1) + Base.getfield(dptrs[i], symname + 1) end - setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) + setfield!( + dptrs[i], + symname + 1, + recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive), + ) end end end @@ -862,8 +1225,8 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -899,18 +1262,23 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta if !is_constant_value(gutils, ops[2]) forgep = LLVM.addrspacecast!(B, forgep, LLVM.PointerType(T_jlvalue, Derived)) forgep = LLVM.pointercast!(B, forgep, LLVM.PointerType(AT, Derived)) - end + end ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width + for i = 1:width if !is_constant_value(gutils, ops[2]) - gep = LLVM.inbounds_gep!(B, AT, forgep, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) 
+ gep = LLVM.inbounds_gep!( + B, + AT, + forgep, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) else ld = forgep end - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end shadowres = shadow end @@ -927,7 +1295,13 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) if needsShadowP[] == 0 return end @@ -936,7 +1310,7 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) width = get_width(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - + if !is_constant_value(gutils, ops[2]) inp = invert_pointer(gutils, ops[2], B) inp = lookup_value(gutils, inp, B) @@ -944,8 +1318,8 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -988,21 +1362,22 @@ end shadowin = invert_pointer(gutils, origops[1], B) if width == 1 args = LLVM.Value[ - shadowin - new_from_original(gutils, origops[2]) - ] + shadowin + new_from_original(gutils, origops[2]) + ] shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[2]) - ] + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[2]) + ] 
tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1011,9 +1386,11 @@ end if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1029,7 +1406,7 @@ end width = get_width(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - + T_int8 = LLVM.Int8Type() T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -1040,8 +1417,8 @@ end inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -1077,18 +1454,23 @@ end if !is_constant_value(gutils, ops[1]) forgep = LLVM.addrspacecast!(B, forgep, LLVM.PointerType(T_jlvalue, Derived)) forgep = LLVM.pointercast!(B, forgep, LLVM.PointerType(AT, Derived)) - end + end ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width + for i = 1:width if !is_constant_value(gutils, ops[1]) - gep = LLVM.inbounds_gep!(B, AT, forgep, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + gep = LLVM.inbounds_gep!( + B, + AT, + forgep, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) else ld = forgep end - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end shadowres = shadow end @@ -1104,19 +1486,25 @@ end needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - 
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 - if !needsShadow - return - end + if !needsShadow + return + end ops = collect(operands(orig)) width = get_width(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - + if !is_constant_value(gutils, ops[1]) inp = invert_pointer(gutils, ops[1], B) inp = lookup_value(gutils, inp, B) @@ -1124,8 +1512,8 @@ end inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -1170,7 +1558,8 @@ end end function common_setfield_fwd(offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1188,34 +1577,48 @@ function common_setfield_fwd(offset, B, orig, gutils, normalR, shadowR) shadowout = invert_pointer(gutils, origops[4], B) if width == 1 args = LLVM.Value[ - new_from_original(gutils, origops[1]) - shadowin - new_from_original(gutils, origops[3]) - shadowout - ] - valTys = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal, API.VT_Shadow] + new_from_original(gutils, origops[1]) + shadowin + new_from_original(gutils, origops[3]) + shadowout + ] + valTys = + API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal, API.VT_Shadow] if offset != 1 pushfirst!(args, first(operands(orig))) pushfirst!(valTys, API.VT_Primal) end - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + shadowres = + call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=# callconv!(shadowres, callconv(orig)) else - for idx in 1:width + for idx = 1:width args = LLVM.Value[ - new_from_original(gutils, origops[1]) - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[3]) - extract_value!(B, shadowout, idx-1) - ] - valTys = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal, API.VT_Shadow] + new_from_original(gutils, origops[1]) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[3]) + extract_value!(B, shadowout, idx - 1) + ] + valTys = API.CValueType[ + API.VT_Primal, + API.VT_Shadow, + API.VT_Primal, + API.VT_Shadow, + ] if offset != 1 pushfirst!(args, first(operands(orig))) pushfirst!(valTys, API.VT_Primal) end - tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + tmp = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + valTys, + false, + ) #=lookup=# callconv!(tmp, callconv(orig)) end @@ -1225,7 +1628,7 @@ function common_setfield_fwd(offset, B, 
orig, gutils, normalR, shadowR) end -function rt_jl_setfield_aug(dptr::T, idx, ::Val{isconst}, val, dval) where {T, isconst} +function rt_jl_setfield_aug(dptr::T, idx, ::Val{isconst}, val, dval) where {T,isconst} RT = Core.Typeof(val) if active_reg(RT) setfield!(dptr, idx, make_zero(val)) @@ -1234,7 +1637,7 @@ function rt_jl_setfield_aug(dptr::T, idx, ::Val{isconst}, val, dval) where {T, i end end -function rt_jl_setfield_rev(dptr::T, idx, ::Val{isconst}, val, dval) where {T, isconst} +function rt_jl_setfield_rev(dptr::T, idx, ::Val{isconst}, val, dval) where {T,isconst} RT = Core.Typeof(val) if active_reg(RT) && !isconst dval[] += getfield(dptr, idx) @@ -1244,7 +1647,8 @@ end function common_setfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1263,15 +1667,16 @@ function common_setfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - for idx in 1:width + for idx = 1:width vals = LLVM.Value[ - (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx-1), - new_from_original(gutils, origops[3]), - unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), - new_from_original(gutils, origops[4]), - is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), + (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx - 1), + new_from_original(gutils, origops[3]), + unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), + new_from_original(gutils, origops[4]), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : + ((width == 1) ? 
shadowval : extract_value!(B, shadowval, idx - 1)), ] - + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_setfield_aug)) cal = emit_apply_generic!(B, vals) @@ -1295,18 +1700,27 @@ function common_setfield_rev(offset, B, orig, gutils, tape) else nothing end - + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - for idx in 1:width + for idx = 1:width vals = LLVM.Value[ - lookup_value(gutils, (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx-1), B), - lookup_value(gutils, new_from_original(gutils, origops[3]), B), - unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), - lookup_value(gutils, new_from_original(gutils, origops[4]), B), - is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : lookup_value(gutils, ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), B), + lookup_value( + gutils, + (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx - 1), + B, + ), + lookup_value(gutils, new_from_original(gutils, origops[3]), B), + unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), + lookup_value(gutils, new_from_original(gutils, origops[4]), B), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : + lookup_value( + gutils, + ((width == 1) ? 
shadowval : extract_value!(B, shadowval, idx - 1)), + B, + ), ] - + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_setfield_rev)) cal = emit_apply_generic!(B, vals) @@ -1314,7 +1728,7 @@ function common_setfield_rev(offset, B, orig, gutils, tape) debug_from_orig!(gutils, cal, orig) end end - return nothing + return nothing end @@ -1330,7 +1744,7 @@ end common_setfield_rev(1, B, orig, gutils, tape) end -function error_if_differentiable(::Type{T}) where T +function error_if_differentiable(::Type{T}) where {T} seen = () areg = active_reg_inner(T, seen, nothing) if areg != AnyState @@ -1349,13 +1763,13 @@ function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) origmi, origh, origkey = operands(orig)[offset:end-1] shadowh = invert_pointer(gutils, origh, B) - + newvals = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal] if offset != 1 pushfirst!(newvals, API.VT_Primal) end - + mi = new_from_original(gutils, origmi) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1365,27 +1779,50 @@ function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) if offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) - + if is_constant_value(gutils, origh) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + emit_apply_generic!( + B, + LLVM.Value[ + unsafe_to_llvm(B, error_if_differentiable), + emit_jltypeof!(B, cal), + ], + ) end cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[mi, extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey)] + for j = 1:width + newops = LLVM.Value[ + mi, + extract_value!(B, shadowh, j - 1), + new_from_original(gutils, origkey), + ] if 
offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) if is_constant_value(gutils, origh) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + emit_apply_generic!( + B, + LLVM.Value[ + unsafe_to_llvm(B, error_if_differentiable), + emit_jltypeof!(B, cal), + ], + ) end - shadow = insert_value!(B, shadow, cal, j-1) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -1405,19 +1842,19 @@ function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tap origmi, origh, origkey = operands(orig)[offset:end-1] shadowh = invert_pointer(gutils, origh, B) - + newvals = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal] if offset != 1 pushfirst!(newvals, API.VT_Primal) end - + errfn = if is_constant_value(gutils, origh) error_if_differentiable else error_if_active end - + mi = new_from_original(gutils, origmi) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1427,24 +1864,38 @@ function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tap if offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) - - + + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)]) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[mi, extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey)] + for j = 1:width + newops = LLVM.Value[ + mi, + extract_value!(B, shadowh, j - 1), + 
new_from_original(gutils, origkey), + ] if offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)]) - shadow = insert_value!(B, shadow, cal, j-1) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)], + ) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -1463,7 +1914,8 @@ function common_finalizer_fwd(offset, B, orig, gutils, normalR, shadowR) return true end emit_error(B, orig, "Enzyme: unhandled forward for jl_f_finalizer") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1475,7 +1927,8 @@ function common_finalizer_augfwd(offset, B, orig, gutils, normalR, shadowR, tape return true end emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_f_finalizer") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end diff --git a/src/typeanalysis.jl b/src/typeanalysis.jl index a1d90ba81f..a84f96f856 100644 --- a/src/typeanalysis.jl +++ b/src/typeanalysis.jl @@ -7,7 +7,10 @@ end Base.unsafe_convert(::Type{API.EnzymeTypeAnalysisRef}, ta::TypeAnalysis) = ta.ref LLVM.dispose(ta::TypeAnalysis) = API.FreeTypeAnalysis(ta) -function TypeAnalysis(logic, typerules::Dict{String, CustomRuleType}=Dict{String,CustomRuleType}()) +function TypeAnalysis( + logic, + typerules::Dict{String,CustomRuleType} = Dict{String,CustomRuleType}(), +) rulenames = String[] rules = CustomRuleType[] for (rulename, rule) in typerules @@ -20,4 +23,4 @@ end # typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/, # CTypeTree * /*args*/, size_t /*numArgs*/, -# LLVMValueRef)=T \ No newline at end of file +# LLVMValueRef)=T diff --git a/src/typetree.jl b/src/typetree.jl index 89e5a040f3..8ddce070b2 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -51,7 +51,7 @@ function shift!(tt::TypeTree, dl, offset, maxSize, addOffset) API.EnzymeTypeTreeShiftIndiciesEq(tt, dl, offset, maxSize, addOffset) end -function merge!(dst::TypeTree, src::TypeTree; consume=true) +function merge!(dst::TypeTree, src::TypeTree; consume = true) API.EnzymeMergeTypeTree(dst, src) if consume LLVM.dispose(src) @@ -80,28 +80,12 @@ end @static if VERSION >= v"1.11-" -const TypeTreePrimitives = ( - Char, - Float16, - Float32, - Float64, - Core.BFloat16 -) + const TypeTreePrimitives = (Char, Float16, Float32, Float64, Core.BFloat16) else -const TypeTreePrimitives = ( - Char, - Float16, - Float32, - Float64 -) + const TypeTreePrimitives = (Char, Float16, Float32, Float64) end -const TypeTreeEmptyPointers = ( - BigFloat, - Any, - Symbol, - Union{}, -) +const TypeTreeEmptyPointers = (BigFloat, Any, Symbol, Union{}) function get_offsets(@nospecialize(T::Type)) for sT in (Integer, TypeTreePrimitives...) 
@@ -109,18 +93,22 @@ function get_offsets(@nospecialize(T::Type)) return ((typetree_primitive(T), 0),) end end - for sT in (DataType, AbstractString, TypeTreeEmptyPointers...) + for sT in (DataType, AbstractString) if T <: sT return ((API.DT_Pointer, 0),) end end - -@static if VERSION < v"1.11-" - TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array) -else - TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array, GenericMemory) -end for sT in TypeTreeEmptyPointers + if T == sT + return ((API.DT_Pointer, 0),) + end + end + @static if VERSION < v"1.11-" + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array) + else + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array, GenericMemory) + end + for sT in TypeTreePtrs if T <: sT return ((API.DT_Pointer, 0),) end @@ -132,8 +120,8 @@ end return () end - results = Tuple{API.CConcreteType, Int}[] - for f in 1:fieldcount(T) + results = Tuple{API.CConcreteType,Int}[] + for f = 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) @@ -141,9 +129,9 @@ end push!(results, (API.DT_Pointer, offset)) continue end - + for (sT, sO) in get_offsets(subT) - push!(results, (sT, sO+offset)) + push!(results, (sT, sO + offset)) end end return results @@ -173,10 +161,17 @@ function to_fullmd(@nospecialize(T::Type)) end function to_md(tt::TypeTree, ctx) - return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), - LLVM.API.LLVMValueRef, - (API.CTypeTreeRef, - LLVM.API.LLVMContextRef), tt, ctx))) + return LLVM.Metadata( + LLVM.MetadataAsValue( + ccall( + (:EnzymeTypeTreeToMD, API.libEnzyme), + LLVM.API.LLVMValueRef, + (API.CTypeTreeRef, LLVM.API.LLVMContextRef), + tt, + ctx, + ), + ), + ) end const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}} @@ -190,7 +185,7 @@ Construct a Enzyme typetree from a Julia type. When using a memoized lookup by providing `seen` across multiple calls to typtree the user must call `copy` on the returned value before mutating it. 
""" -function typetree(@nospecialize(T::Type), ctx, dl, seen=TypeTreeTable()) +function typetree(@nospecialize(T::Type), ctx, dl, seen = TypeTreeTable()) if haskey(seen, T) tree = seen[T] if tree === nothing @@ -209,7 +204,7 @@ function typetree_inner(::Type{<:Integer}, ctx, dl, seen::TypeTreeTable) end for sT in TypeTreePrimitives @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) - return TypeTree($(typetree_primitive(sT)), -1, ctx) + return TypeTree($(typetree_primitive(sT)), -1, ctx) end end @@ -221,21 +216,25 @@ function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) end for sT in TypeTreeEmptyPointers @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) - return TypeTree() + return TypeTree() end end function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() - for i in 0:(sizeof(Csize_t) - 1) + for i = 0:(sizeof(Csize_t)-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt end -function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, - seen::TypeTreeTable) where {T} +function typetree_inner( + ::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, + ctx, + dl, + seen::TypeTreeTable, +) where {T} tt = copy(typetree(T, ctx, dl, seen)) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) @@ -261,13 +260,18 @@ end sizeofstruct += sizeof(Csize_t) end - for i in offset:(sizeofstruct-1) + for i = offset:(sizeofstruct-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt end else - function typetree_inner(::Type{<:GenericMemory{kind, T}}, ctx, dl, seen::TypeTreeTable) where {kind, T} + function typetree_inner( + ::Type{<:GenericMemory{kind,T}}, + ctx, + dl, + seen::TypeTreeTable, + ) where {kind,T} offset = 0 tt = copy(typetree(T, ctx, dl, seen)) if !allocatedinline(T) && Base.isconcretetype(T) @@ -277,7 +281,7 @@ else merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, sizeof(Csize_t)) - for i in 0:(sizeof(Csize_t)-1) + for i = 
0:(sizeof(Csize_t)-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt @@ -327,7 +331,7 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) end tt = TypeTree() - for f in 1:fieldcount(T) + for f = 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) subtree = copy(typetree(subT, ctx, dl, seen)) @@ -358,7 +362,10 @@ struct FnTypeInfo end Base.cconvert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) = fnti function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) - args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values)) + args_kv = Base.unsafe_convert( + Ptr{API.IntList}, + Base.cconvert(Ptr{API.IntList}, fnti.known_values), + ) rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT)) tts = API.CTypeTreeRef[] @@ -366,6 +373,9 @@ function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt)) push!(tts, raw_tt) end - argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts)) + argTTs = Base.unsafe_convert( + Ptr{API.CTypeTreeRef}, + Base.cconvert(Ptr{API.CTypeTreeRef}, tts), + ) return API.CFnTypeInfo(argTTs, rTT, args_kv) end diff --git a/src/utils.jl b/src/utils.jl index ac312e8295..d042859b89 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,10 +5,16 @@ Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. Only use inside Enzyme.jl should be for Types. 
""" -@inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(Base.identity, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) +@inline unsafe_to_pointer(val::Type{T}) where {T} = ccall( + Base.@cfunction(Base.identity, Ptr{Cvoid}, (Ptr{Cvoid},)), + Ptr{Cvoid}, + (Any,), + val, +) export unsafe_to_pointer -@inline is_concrete_tuple(x::Type{T2}) where T2 = (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) +@inline is_concrete_tuple(x::Type{T2}) where {T2} = + (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) export is_concrete_tuple const Tracked = 10 @@ -20,11 +26,11 @@ const captured_constants = Base.IdSet{Any}() function unsafe_nothing_to_llvm(mod::LLVM.Module) globs = LLVM.globals(mod) k = "jl_nothing" - if Base.haskey(globs, "ejl_"*k) + if Base.haskey(globs, "ejl_" * k) return globs["ejl_"*k] end T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) - gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_" * k, Tracked) API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) @@ -56,13 +62,13 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) if v === val mod = LLVM.parent(LLVM.parent(LLVM.position(B))) globs = LLVM.globals(mod) - if Base.haskey(globs, "ejl_"*k) + if Base.haskey(globs, "ejl_" * k) return globs["ejl_"*k] end - gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_" * k, Tracked) API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) - legal, jTy = Compiler.abs_typeof(gv, true) + legal, jTy, byref = Compiler.abs_typeof(gv, true) if legal curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -78,12 +84,12 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) if v === val mod = LLVM.parent(LLVM.parent(LLVM.position(B))) globs = LLVM.globals(mod) - if Base.haskey(globs, "ejl_"*k) + if Base.haskey(globs, "ejl_" * k) 
return globs["ejl_"*k] end - gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_" * k, Tracked) API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) - legal, jTy = Compiler.abs_typeof(gv, true) + legal, jTy, byref = Compiler.abs_typeof(gv, true) if legal curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -153,7 +159,11 @@ using Base: _methods_by_ftype # on 1.10 (JuliaLang/julia#48611) the generated function knows which world it was invoked in function _generated_ex(world, source, ex) - stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec()) + stub = Core.GeneratedFunctionStub( + identity, + Core.svec(:methodinstance, :ft, :tt), + Core.svec(), + ) stub(world, source, ex) end @@ -164,23 +174,38 @@ function codegen_world_age_generator(world::UInt, source, self, ft::Type, tt::Ty tt = tt.parameters[1] # validation - ft <: Core.Builtin && error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") + ft <: Core.Builtin && + error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") # look up the method method_error = :(throw(MethodError(ft, tt, $world))) - sig = Tuple{ft, tt.parameters...} + sig = Tuple{ft,tt.parameters...} min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - mthds = Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, - world, #=ambig=# false, - min_world, max_world, has_ambig) + mthds = Base._methods_by_ftype( + sig, + nothing, + -1, #=lim=# + world, + false, #=ambig=# + min_world, + max_world, + has_ambig, + ) mthds === nothing && return _generated_ex(world, source, method_error) length(mthds) == 1 || return _generated_ex(world, source, method_error) # look up the method and code instance mtypes, msp, m = mthds[1] - mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp) 
+ mi = ccall( + :jl_specializations_get_linfo, + Ref{MethodInstance}, + (Any, Any, Any), + m, + mtypes, + msp, + ) ci = retrieve_code_info(mi, world)::CodeInfo # prepare a new code info @@ -222,8 +247,3 @@ end end export codegen_world_age - - - - - diff --git a/test/runtests.jl b/test/runtests.jl index 573140f2c2..92bfa47513 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2856,6 +2856,17 @@ end @test dx[3] ≈ 0 end +function unstable_fun(A0) + A = 'N' in ('H', 'h', 'S', 's') ? wrap(A0) : A0 + (@inbounds A[1])::eltype(A0) +end +@testset "Type unstable static array index" begin + inp = ones(SVector{2, Float64}) + res = Enzyme.gradient(Enzyme.Reverse, unstable_fun, inp)[1] + @test res ≈ [1.0, 0.0] + res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1] + @test res ≈ [1.0, 0.0] +end function sparse_eval(x::Vector{Float64}) A = sparsevec([1, 1, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0]) diff --git a/test/typetree.jl b/test/typetree.jl index 51c284d6e9..1a869d6687 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -79,3 +79,16 @@ end "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" end end + +@testset "GetOffsets" begin + @test Enzyme.get_offsets(Float16) == ((Enzyme.API.DT_Half,0),) + @test Enzyme.get_offsets(Float32) == ((Enzyme.API.DT_Float,0),) + @test Enzyme.get_offsets(Float64) == ((Enzyme.API.DT_Double,0),) + @test Enzyme.get_offsets(Int) == ((Enzyme.API.DT_Integer,0),) + @test Enzyme.get_offsets(Char) == ((Enzyme.API.DT_Integer,0),) + @test Enzyme.get_offsets(Ptr) == ((Enzyme.API.DT_Pointer,0),) + @test Enzyme.get_offsets(Ptr{Char}) == ((Enzyme.API.DT_Pointer,0),) + @test Enzyme.get_offsets(Ptr{Float32}) == ((Enzyme.API.DT_Pointer,0),) + @test Enzyme.get_offsets(Vector{Float32}) == ((Enzyme.API.DT_Pointer,0),) + @test 
Enzyme.get_offsets(Tuple{Float64, Int}) == [(Enzyme.API.DT_Double,0),(Enzyme.API.DT_Integer, 8)] +end From 63921794388c1ed36343e0d0ef5c4e7ddb03f0dd Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 24 Sep 2024 11:35:55 -0500 Subject: [PATCH 304/495] Update Project.toml (#1885) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5a0e192de5..3c93057f90 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.3" +version = "0.13.2" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From fe41647c22a1e939e7d5ef1ba1d9470833c420a3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 24 Sep 2024 13:32:46 -0500 Subject: [PATCH 305/495] fix array (#1884) * fix array * fix * Update absint.jl * fix * fix * Fix flake * sym --- src/absint.jl | 98 ++++++++++++++++++++++++++---------------- test/internal_rules.jl | 2 +- test/runtests.jl | 29 +++++++++++++ 3 files changed, 92 insertions(+), 37 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 585b1625a3..03cf53cf4e 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -158,7 +158,7 @@ function absint(arg::LLVM.Value, partial::Bool = false) end function actual_size(@nospecialize(typ2)) - if typ2 <: Array || typ2 <: AbstractString + if typ2 <: Array || typ2 <: AbstractString || typ2 <: Symbol return sizeof(Int) elseif Base.isconcretetype(typ2) return sizeof(typ2) @@ -359,52 +359,78 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - if offset === nothing - byref = GPUCompiler.BITS_VALUE - legal = true - typ2 = typ - while actual_size(typ2) != sizeof(dl, value_type(arg)) - if fieldcount(typ2) > 0 - typ2 = fieldtype(typ, 1) - if !Base.allocatedinline(typ2) - if byref != GPUCompiler.BITS_VALUE - legal = false - break 
+ function should_recurse(typ2, arg_t) + if actual_size(typ2) != sizeof(dl, arg_t) + return true + else + if Base.isconcretetype(typ2) + if fieldcount(typ2) > 0 + if actual_size(fieldtype(typ2,1)) == actual_size(fieldtype(typ2, 1)) + return true end - byref = GPUCompiler.MUT_REF - continue end end - legal = false - break - end - if legal - return (true, typ2, byref) + return false end - else + end + + byref = GPUCompiler.BITS_VALUE + legal = true + + while offset !== nothing && legal @assert Base.isconcretetype(typ) + seen = false + lasti = 1 for i = 1:fieldcount(typ) + fo = fieldoffset(typ, i) if fieldoffset(typ, i) == offset - subT = fieldtype(typ, i) - fsize = if i == fieldcount(typ) - sizeof(typ) - else - fieldoffset(typ, i + 1) - end - offset - if fsize == sizeof(dl, value_type(arg)) - if Base.isconcretetype(subT) && - is_concrete_tuple(subT) && - length(subT.parameters) == 1 - subT = subT.parameters[1] - end - if Base.allocatedinline(subT) - return (true, subT, GPUCompiler.BITS_VALUE) - else - return (true, subT, GPUCompiler.MUT_REF) + offset = nothing + typ = fieldtype(typ, i) + if !Base.allocatedinline(typ) + if byref != GPUCompiler.BITS_VALUE + legal = false end + byref = GPUCompiler.MUT_REF + end + seen = true + break + elseif fieldoffset(typ, i) > offset + offset = offset - fieldoffset(typ, lasti) + typ = fieldtype(typ, lasti) + if !Base.allocatedinline(typ) + legal = false + end + seen = true + break + end + + if fo != 0 && fo != fieldoffset(typ, i-1) + lasti = i + end + end + if !seen + legal = false + end + end + + typ2 = typ + while should_recurse(typ2, value_type(arg)) + if fieldcount(typ2) > 0 + typ2 = fieldtype(typ2, 1) + if !Base.allocatedinline(typ2) + if byref != GPUCompiler.BITS_VALUE + legal = false + break end + byref = GPUCompiler.MUT_REF + continue end end + legal = false + break + end + if legal + return (true, typ2, byref) end end elseif legal && if typ <: Ptr && Base.isconcretetype(typ) diff --git a/test/internal_rules.jl 
b/test/internal_rules.jl index 32a206c62e..0d5bbdae01 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -591,7 +591,7 @@ end TM in (Const, Duplicated, BatchDuplicated), TB in (Const, Duplicated, BatchDuplicated) are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const)) + test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const); atol = 1.0e-5, rtol = 1.0e-5) end end @testset "test through `Adjoint` wrapper (regression test for #1306)" begin diff --git a/test/runtests.jl b/test/runtests.jl index 92bfa47513..69e6d51cd5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3959,6 +3959,35 @@ function harmonic_f!(inter_list, coords, inters) return si end +function invwsumsq(w::AbstractVector, a::AbstractVector) + s = zero(zero(eltype(a)) / zero(eltype(w))) + for i in eachindex(w) + s += abs2(a[i]) / w[i] + end + return s +end + +_logpdf(d, x) = invwsumsq(d.Σ.diag, x .- d.μ) + +function demo_func(x::Any=transpose([1.5 2.0;]);) + m = [-0.30725218207431315, 0.5492115788562757] + d = (; Σ = LinearAlgebra.Diagonal([1.0, 1.0]), μ = m) + logp = _logpdf(d, reshape(x, (2,))) + return logp +end + +demof(x) = demo_func() + +@testset "Type checks" begin + x = [0.0, 0.0] + Enzyme.autodiff( + Enzyme.Reverse, + Enzyme.Const(demof), + Enzyme.Active, + Enzyme.Duplicated(x, zero(x)), + ) +end + @testset "Decay preservation" begin inters = [HarmonicAngle(1.0, 0.1), HarmonicAngle(2.0, 0.3)] inter_list = [1, 3] From 9f6663311d18fa72b3de7e68eca3287e0aa31cc3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Sep 2024 12:15:24 -0500 Subject: [PATCH 306/495] Fix abs cstring (#1888) --- Project.toml | 2 +- src/absint.jl | 54 +++++++++++++++++++++++++-------------------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 3c93057f90..5a0e192de5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = 
"7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.2" +version = "0.13.3" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/absint.jl b/src/absint.jl index 03cf53cf4e..0739b7ded5 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -489,30 +489,30 @@ function abs_typeof( end return (false, nothing, nothing) end -# -# function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} -# if isa(arg, ConstantExpr) -# ce = arg -# while isa(ce, ConstantExpr) -# if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr -# ce = operands(ce)[1] -# elseif opcode(ce) == LLVM.API.LLVMGetElementPtr -# if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) -# ce = operands(ce)[1] -# else -# break -# end -# else -# break -# end -# end -# if isa(ce, LLVM.GlobalVariable) -# ce = LLVM.initializer(ce) -# if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) -# return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) -# end -# -# end -# end -# return (false, "") -# end + +function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} + if isa(arg, ConstantExpr) + ce = arg + while isa(ce, ConstantExpr) + if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr + ce = operands(ce)[1] + elseif opcode(ce) == LLVM.API.LLVMGetElementPtr + if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) + ce = operands(ce)[1] + else + break + end + else + break + end + end + if isa(ce, LLVM.GlobalVariable) + ce = LLVM.initializer(ce) + if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) + return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) + end + + end + end + return (false, "") +end From 
519b693c7c8530414352c6a5a3e7f6c5c7fa4be1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Sep 2024 12:15:42 -0500 Subject: [PATCH 307/495] recfix (#1886) --- src/absint.jl | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 0739b7ded5..27f66d2c06 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -359,18 +359,24 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - function should_recurse(typ2, arg_t) - if actual_size(typ2) != sizeof(dl, arg_t) - return true + function should_recurse(typ2, arg_t, byref) + sz = sizeof(dl, arg_t) + if byref != GPUCompiler.BITS_VALUE + @assert sz == sizeof(Int) + return false else - if Base.isconcretetype(typ2) - if fieldcount(typ2) > 0 - if actual_size(fieldtype(typ2,1)) == actual_size(fieldtype(typ2, 1)) - return true + if actual_size(typ2) != sz + return true + else + if Base.isconcretetype(typ2) + if fieldcount(typ2) > 0 + if actual_size(fieldtype(typ2,1)) == sz + return true + end end end + return false end - return false end end @@ -414,7 +420,7 @@ function abs_typeof( end typ2 = typ - while should_recurse(typ2, value_type(arg)) + while should_recurse(typ2, value_type(arg), byref) if fieldcount(typ2) > 0 typ2 = fieldtype(typ2, 1) if !Base.allocatedinline(typ2) From a72efb1ba163158c199a4811d4839b2640e4dd91 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Sep 2024 22:26:13 -0500 Subject: [PATCH 308/495] Concrete type assertion (#1890) * Concrete type assertion * fix * fix * fix --- src/absint.jl | 72 ++++++++++++++++++++++++++++++------------------- src/compiler.jl | 4 ++- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 27f66d2c06..9eec24bcf3 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -167,6 +167,43 @@ function actual_size(@nospecialize(typ2)) end end +@inline function 
first_non_ghost(@nospecialize(typ2)) + fc = fieldcount(typ2) + for i in 1:fc + if i == fc + return (i, sizeof(typ2)) + else + fo = fieldoffset(typ2, i+1) + if fo != 0 + return (i, fo) + end + end + end + return (-1, 0) +end + +function should_recurse(@nospecialize(typ2), arg_t, byref, dl) + sz = sizeof(dl, arg_t) + if byref != GPUCompiler.BITS_VALUE + @assert sz == sizeof(Int) + return false + else + if actual_size(typ2) != sz + return true + else + if Base.isconcretetype(typ2) + idx, sz2 = first_non_ghost(typ2) + if idx != -1 + if sz2 == sz + return true + end + end + end + return false + end + end +end + function abs_typeof( arg::LLVM.Value, partial::Bool = false, @@ -346,7 +383,7 @@ function abs_typeof( if !error legal, typ, byref = abs_typeof(larg) - if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) + if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) && Base.isconcretetype(typ) @static if VERSION < v"1.11-" if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) @@ -359,31 +396,11 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - function should_recurse(typ2, arg_t, byref) - sz = sizeof(dl, arg_t) - if byref != GPUCompiler.BITS_VALUE - @assert sz == sizeof(Int) - return false - else - if actual_size(typ2) != sz - return true - else - if Base.isconcretetype(typ2) - if fieldcount(typ2) > 0 - if actual_size(fieldtype(typ2,1)) == sz - return true - end - end - end - return false - end - end - end byref = GPUCompiler.BITS_VALUE legal = true - while offset !== nothing && legal + while (offset !== nothing && offset != 0) && legal @assert Base.isconcretetype(typ) seen = false lasti = 1 @@ -403,6 +420,7 @@ function abs_typeof( elseif fieldoffset(typ, i) > offset offset = offset - fieldoffset(typ, lasti) typ = fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) if !Base.allocatedinline(typ) legal = 
false end @@ -420,9 +438,10 @@ function abs_typeof( end typ2 = typ - while should_recurse(typ2, value_type(arg), byref) - if fieldcount(typ2) > 0 - typ2 = fieldtype(typ2, 1) + while should_recurse(typ2, value_type(arg), byref, dl) + idx, _ = first_non_ghost(typ2) + if idx != -1 + typ2 = fieldtype(typ2, idx) if !Base.allocatedinline(typ2) if byref != GPUCompiler.BITS_VALUE legal = false @@ -439,10 +458,9 @@ function abs_typeof( return (true, typ2, byref) end end - elseif legal && if typ <: Ptr && Base.isconcretetype(typ) + elseif legal && typ <: Ptr && Base.isconcretetype(typ) return (true, eltype(typ), GPUCompiler.BITS_VALUE) end - end end end diff --git a/src/compiler.jl b/src/compiler.jl index f3680cc4c0..32dd293ece 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7936,8 +7936,10 @@ function GPUCompiler.codegen( elseif byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF Ptr{source_typ} else - println(string(mod)) + # println(string(mod)) + println(string(f)) @show legal, source_typ, byref, llvm_source_typ, codegen_typ, string(inst) + @show enzyme_custom_extract_mi(f) @assert false end else From 31c60beca9b422adbd2f7d86e32802e65eaad31b Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 26 Sep 2024 10:56:35 +0200 Subject: [PATCH 309/495] remove spurious checkin --- test/ext/.chainrulescore.jl.swp | Bin 12288 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/ext/.chainrulescore.jl.swp diff --git a/test/ext/.chainrulescore.jl.swp b/test/ext/.chainrulescore.jl.swp deleted file mode 100644 index 94b31875e128106c0c92d4e28639ebc06d4e3451..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI2&u<$=6vrpMP@ts<4i&dYe85{A{|b?4R99(3f{;oMsNqtoXl&2A3-+$NJ67x- zhy;H@4^W`M0d55$aYLe9;D)$!;m83YAt51gtoY8%I=gn_G;J?LdMkZ)*R$`vdGEV# zY(<%7_088_r?cg%!11&YU*B5xfBJjp*r^k9V!Ib5DlrcZKAhb`Pqvy-^iDsH>g70+ zy>bw06^F7r_I}qHRyu*Mtc5p5Jym|YThS`f6*bhA)-@S~@t~`cRur2@V?VyK0<6Fz 
zDlpLg!pkQ&7wa{<)C=>^)3eWPK4MUIX9ZXRR)7^?1y})AfE8c`Sb?LYfbP$U9VC1# zP5NAVotwI*ANj%xumY?AE5Hh{0;~WlzzVPetN<&(3a|o4PyyK%;^t#Qe0&_q6AN##eJ+nEJH$IoX27 zKI7{leT^9ov$Jr^kK%0^w~I*>>k?g#70$&W?uMQxtQk%PdjnGpQxH<|l|jE7w4_$; z5?z$KbuAmndsz>J+~bp(Z$ukPwKJ#OoD=z7l!+?4B&Q~H`L0G`q9;;5q`OPtL4#xk z?C&{i^33m`d1rY~Rh|zq^(d#rtYG`6l8qiB#MGotimAavgvsV~5wg#4&Nakjo1328 zJ7Pv>Rs1mB?fOcXu;wVxL-cgInRg}V^|8vfe6xv{^r){Qzqs4i9?^wxBgH4}SFesAS(_D8WP+tN-!$uw+Pf=qo|>wIMe3BZcI3AM-!G16;XL_^MH_bH znb00J8Vxf=7Bn>;rpEX#fPKz*DK$Q*Du2;4)mV_W4ZdpjFTY{-FRz;Y%PUDsc4fR} z#XwS+X9Nw3OJ|}`6^EE`%NCnc{5y`kS=&4X+8#d;le6@bvWj?b) zYv5O%1WUe?;J~Wt9<6sHsTb?_uc(e1^5vw(*kN0V4N;(^uUdg0v?vKWq11y|nL_p{ zrS)aj!ib- Op22j$L`rX6Xz?G6bcs*^ From 0f7b3557e0c1791105718ae5417e30aae72cb27d Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 11:11:56 -0500 Subject: [PATCH 310/495] CustomRules: fix body check (#1896) --- src/rules/customrules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 1985283da3..cb6c60d98d 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -977,7 +977,7 @@ end !(aug_RT === Union{}) TapeT = EnzymeRules.tape_type(aug_RT) elseif (aug_RT isa UnionAll) && - (aug_RT <: EnzymeRules.AugmentedReturn) && + (aug_RT <: EnzymeRules.AugmentedReturn) && hasfield(typeof(aug_RT.body), :name) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar TapeT = aug_RT.body.parameters[3].ub @@ -985,7 +985,7 @@ end TapeT = Any end elseif (aug_RT isa UnionAll) && - (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && + (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && hasfield(typeof(aug_RT.body), :name) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar From 9495698d2d1fa163413ed6a3e113a29a02292ea4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 
14:32:17 -0500 Subject: [PATCH 311/495] Sparsearrays ext (#1891) * Sparsearrays ext * fix --- Project.toml | 1 + src/compiler.jl | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5a0e192de5..f2b99062a0 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" diff --git a/src/compiler.jl b/src/compiler.jl index 32dd293ece..1417379a83 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -38,7 +38,7 @@ import Enzyme_jll import GPUCompiler: CompilerJob, codegen, safe_name using LLVM.Interop import LLVM: Target, TargetMachine - +import SparseArrays using Printf using Preferences @@ -522,6 +522,7 @@ end @inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T @inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V @inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V +@inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T @inline is_arrayorvararg_ty(::Type) = false @inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true @@ -533,6 +534,7 @@ end @inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true +@inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true @inline function datatype_fieldcount(t::Type{T}) where {T} return Base.datatype_fieldcount(t) From f91eabb764d6d0e0d24b6e929a2aa0ffc86aec9b Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 26 Sep 2024 17:16:26 -0400 Subject: [PATCH 312/495] Add WithPrimal and NoPrimal function (#1898) * Add WithPrimal and NoPrimal function * version 
bumps --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 23 +++++++++++++++++++++++ src/Enzyme.jl | 6 +++++- test/runtests.jl | 19 +++++++++++++++++++ 5 files changed, 49 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index f2b99062a0..4cffef367c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.3" +version = "0.13.4" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 37ddaf6457..3a871b930c 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.2" +version = "0.8.3" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f51c742f5d..3231674de5 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -244,6 +244,21 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa @inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}() @inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}() +""" + WithPrimal(::Enzyme.Mode) + +Modifies the mode to include the primal value. 
+""" +@inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() + +""" + NoPrimal(::Enzyme.Mode) + +Modifies the mode to exclude the primal value. +""" +@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() + + """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} @@ -267,6 +282,10 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau @inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() @inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where 
{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() + + """ struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} @@ -286,6 +305,10 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,rt}() @inline clear_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,false}() +@inline WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}() +@inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() + + function autodiff end function autodiff_deferred end function autodiff_thunk end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c99114e038..b49c3738f6 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -46,7 +46,9 @@ import EnzymeCore: set_abi, set_runtime_activity, clear_runtime_activity, - within_autodiff + within_autodiff, + WithPrimal, + NoPrimal export Annotation, Const, Active, @@ -63,6 +65,8 @@ export Annotation, set_abi, set_runtime_activity, clear_runtime_activity, + WithPrimal, + NoPrimal, within_autodiff import EnzymeCore: BatchDuplicatedFunc diff --git a/test/runtests.jl b/test/runtests.jl index 69e6d51cd5..d499febd77 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4066,6 +4066,25 @@ end @test res[2][6] ≈ 6.0 end +@testset 
"WithPrimal" begin + @test WithPrimal(Reverse) === ReverseWithPrimal + @test NoPrimal(Reverse) === Reverse + @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal + @test NoPrimal(ReverseWithPrimal) === Reverse + + @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) + + @test WithPrimal(Forward) === ForwardWithPrimal + @test NoPrimal(Forward) === Forward + @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal + @test NoPrimal(ForwardWithPrimal) === Forward + + @test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal + @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal +end + # TEST EXTENSIONS using SpecialFunctions @testset "SpecialFunctions ext" begin From 3565c573d5b92330e33b3440623b9604ed7ebfbc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 16:45:35 -0500 Subject: [PATCH 313/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4cffef367c..fd0882e97e 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8" +EnzymeCore = "0.8.3" Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" From 17a0c7f9b6dbea81d09e5266a032bb6bfed3b4f3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 17:12:32 -0500 Subject: [PATCH 314/495] Skip japi1 activity rule (#1899) --- src/rules/activityrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 7a940259fa..f84b6befb8 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -1,6 +1,6 @@ function 
julia_activity_rule(f::LLVM.Function) - if startswith(LLVM.name(f)) == "japi3" + if startswith(LLVM.name(f)) == "japi3" || startswith(LLVM.name(f)) == "japi1" return end mi, RT = enzyme_custom_extract_mi(f) From bbafecf3f3f5f05b1bc0f794652309ee28e2b108 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 15:29:29 -0500 Subject: [PATCH 315/495] Support fillarray return (#1901) --- src/compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 1417379a83..36d899ac29 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -8778,6 +8778,8 @@ end function add_one_in_place(x) if x isa Base.RefValue x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) + elseif x isa (Array{T,0} where T) + x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) else error( "Enzyme Mutability Error: Cannot add one in place to immutable value " * From 23c2fde4bb3ed4d3465b72aac78d7052797943a6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 16:22:44 -0500 Subject: [PATCH 316/495] More info for dupnoneed (#1904) * More info for dupnoneed * Update lib/EnzymeCore/src/EnzymeCore.jl Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> * Update EnzymeCore.jl --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- lib/EnzymeCore/src/EnzymeCore.jl | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 3231674de5..f76b302302 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -79,7 +79,27 @@ end DuplicatedNoNeed(x, ∂f_∂x) Like [`Duplicated`](@ref), except also specifies that Enzyme may avoid computing -the original result and only compute the derivative values. +the original result and only compute the derivative values. This creates opportunities +for improved performance. 
+ +```julia + +function square_byref(out, v) + out[] = v * v + nothing +end + +out = Ref(0.0) +dout = Ref(1.0) +Enzyme.autodiff(Reverse, square_byref, DuplicatedNoNeed(out, dout), Active(1.0)) +dout[] + +# output +0.0 +``` + +For example, marking the out variable as `DuplicatedNoNeed` instead of `Duplicated` allows +Enzyme to avoid computing `v * v` (while still computing its derivative). This should only be used if `x` is a write-only variable. Otherwise, if the differentiated function stores values in `x` and reads them back in subsequent computations, using From d092d4ab20cf8a01f489bc2822f0e6e6538df549 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 16:50:20 -0500 Subject: [PATCH 317/495] Abs int end of load (#1905) --- src/absint.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 9eec24bcf3..041b3bd1cc 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -432,13 +432,22 @@ function abs_typeof( lasti = i end end + if !seen && fieldcount(typ) > 0 + offset = offset - fieldoffset(typ, lasti) + typ = fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) + if !Base.allocatedinline(typ) + legal = false + end + seen = true + end if !seen legal = false end end typ2 = typ - while should_recurse(typ2, value_type(arg), byref, dl) + while legal && should_recurse(typ2, value_type(arg), byref, dl) idx, _ = first_non_ghost(typ2) if idx != -1 typ2 = fieldtype(typ2, idx) From 126127910baa58642f106f78b47dc1d9e05108f2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Fri, 27 Sep 2024 18:02:48 -0400 Subject: [PATCH 318/495] Add reverse rule for Sparse dense matmul/vec (#1792) * Add sparse array internal rule * Add sparsearray extension for mul! * Add more testing * Add BatchDuplicated (still broken) * Remove BatchMode since it isn't applicable? 
* Add sparse array testing * Don't support batchmode for now * Revert project to old style * Add sparse array compat bound * reenable batch mode for bug hunting * Turn on BatchDuplicated stuff again * Remove Q comment * Encorporate BatchDuplicated into testing properly * Consider constant fp in runtime activity (#1797) * Consider constant fp in runtime activity * fix * Suggest workaround in error for overwritten active by ref (#1791) * Fix custom active reverse mode check (#1798) * Remove Q comment * Encorporate BatchDuplicated into testing properly * Look for more writebarrier opportunities (#1800) * Look for more writebarrier opportunities * Update compiler.jl * Restrict version to 1.10+ (#1809) * Restrict version to 1.10+ * fix * fixup * Update CI.yml * Update Project.toml * Update Project.toml * Update Project.toml * Fix MixedDuplicated ABI error on primalerror (#1815) * Update test * Move new SparseArrays Cholmod into extension * Make LinearAlgebra.mul! explicit * Make sparse arrays not a extension * Fix rules for 0.13 * Remove sparse arrays ext file * Update compiler --------- Co-authored-by: William Moses Co-authored-by: Daniel Wennberg --- Project.toml | 1 + src/Enzyme.jl | 1 + src/internal_rules.jl | 129 ++++++++++++++++++++++++++++++++++++++++- test/internal_rules.jl | 42 ++++++++++++++ test/runtests.jl | 2 +- 5 files changed, 173 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fd0882e97e..9623428c29 100644 --- a/Project.toml +++ b/Project.toml @@ -42,6 +42,7 @@ LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" +SparseArrays = "1" SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.10" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index b49c3738f6..091021cb8b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -98,6 +98,7 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient! 
export markType, batch_size, onehot, chunkedonehot using LinearAlgebra +import SparseArrays import EnzymeCore: ReverseMode, ReverseModeSplit, ForwardMode, Mode import EnzymeCore: EnzymeRules diff --git a/src/internal_rules.jl b/src/internal_rules.jl index f8c6e730bb..b6d081d57d 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -724,6 +724,133 @@ function EnzymeRules.reverse( return (nothing, nothing) end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, + func::Const{typeof(LinearAlgebra.mul!)}, + ::Type{RT}, + C::Annotation{<:StridedVecOrMat}, + A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + B::Annotation{<:StridedVecOrMat}, + α::Annotation{<:Number}, + β::Annotation{<:Number} + ) where {RT} + + cache_C = !(isa(β, Const)) ? copy(C.val) : nothing + # Always need to do forward pass otherwise primal may not be correct + func.val(C.val, A.val, B.val, α.val, β.val) + + primal = if EnzymeRules.needs_primal(config) + C.val + else + nothing + end + + shadow = if EnzymeRules.needs_shadow(config) + C.dval + else + nothing + end + + # Check if A is overwritten and B is active (and thus required) + cache_A = ( EnzymeRules.overwritten(config)[5] + && !(typeof(B) <: Const) + && !(typeof(C) <: Const) + ) ? copy(A.val) : nothing + + # cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing + + if !isa(α, Const) + cache_α = A.val*B.val + else + cache_α = nothing + end + + cache = (cache_C, cache_A, cache_α) + + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, + func::Const{typeof(LinearAlgebra.mul!)}, + ::Type{RT}, cache, + C::Annotation{<:StridedVecOrMat}, + A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + B::Annotation{<:StridedVecOrMat}, + α::Annotation{<:Number}, + β::Annotation{<:Number} + ) where {RT} + + cache_C, cache_A, cache_α = cache + Cval = !isnothing(cache_C) ? cache_C : C.val + Aval = !isnothing(cache_A) ? 
cache_A : A.val + # Bval = !isnothing(cache_B) ? cache_B : B.val + + N = EnzymeRules.width(config) + if !isa(C, Const) + dCs = C.dval + dBs = isa(B, Const) ? dCs : B.dval + + dα = if !isa(α, Const) + if N == 1 + LinearAlgebra.dot(C.dval, cache_α) + else + ntuple(Val(N)) do i + Base.@_inline_meta + LinearAlgebra.dot(C.dval[i], cache_α) + end + end + else + nothing + end + + dβ = if !isa(β, Const) + if N == 1 + LinearAlgebra.dot(C.dval, Cval) + else + ntuple(Val(N)) do i + Base.@_inline_meta + LinearAlgebra.dot(C.dval[i], Cval) + end + end + else + nothing + end + + for i in 1:N + # This rule is incorrect since you need to project dA to have the same + # sparsity pattern as A. + # if !isa(A, Const) + # dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] + # #dA .+= α*dC*B' + # mul!(dA, dC, Bval', α.val, true) + # end + + if !isa(B, Const) + #dB .+= α*A'*dC + if N ==1 + func.val(dBs, Aval', dCs, α.val, true) + else + func.val(dBs[i], Aval', dCs[i], α.val, true) + end + end + + if N==1 + dCs .*= β.val + else + dCs[i] .*= β.val + end + end + end + + return (nothing, nothing, nothing, dα, dβ) +end + + + + + + + function EnzymeRules.forward( config::EnzymeRules.FwdConfig, ::Const{typeof(sort!)}, @@ -1269,4 +1396,4 @@ function EnzymeRules.reverse( smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, ) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} return (nothing, nothing, nothing) -end +end \ No newline at end of file diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 0d5bbdae01..246929272b 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -677,4 +677,46 @@ end # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),) end +@testset "SparseArrays spmatvec reverse rule" begin + C = zeros(18) + M = sprand(18, 9, 0.1) + v = randn(9) + α = 2.0 + β = 1.0 + + for Tret in (Duplicated, BatchDuplicated), Tv in 
(Const, Duplicated, BatchDuplicated), + Tα in (Const, Active), Tβ in (Const, Active) + + are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + + end + + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tret, Tv) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) + end +end + +@testset "SparseArrays spmatmat reverse rule" begin + C = zeros(18, 11) + M = sprand(18, 9, 0.1) + v = randn(9, 11) + α = 2.0 + β = 1.0 + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + Tα in (Const, Active), Tβ in (Const, Active) + + are_activities_compatible(Tret, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + end + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tv) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) + end +end + end # InternalRules diff --git a/test/runtests.jl b/test/runtests.jl index d499febd77..bd1c7dd90d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4101,4 +4101,4 @@ include("ext/logexpfunctions.jl") @testset "BFloat16s ext" begin include("ext/bfloat16s.jl") -end +end \ No newline at end of file From 8e4c50a174fccc9bba81f1a78a16acdf3042cdb3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 18:18:23 -0500 Subject: [PATCH 319/495] Fix japi1 (#1907) --- src/rules/activityrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index f84b6befb8..13bacb06a5 100644 --- a/src/rules/activityrules.jl +++ 
b/src/rules/activityrules.jl @@ -1,6 +1,6 @@ function julia_activity_rule(f::LLVM.Function) - if startswith(LLVM.name(f)) == "japi3" || startswith(LLVM.name(f)) == "japi1" + if startswith(LLVM.name(f), "japi3") || startswith(LLVM.name(f), "japi1") return end mi, RT = enzyme_custom_extract_mi(f) From 5fe7d91d82a5e1c3465836ab09504a8fb1a6464b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 27 Sep 2024 20:18:20 -0500 Subject: [PATCH 320/495] CompatHelper: bump compat for Enzyme_jll to 0.0.151, (keep existing compat) (#1908) * CompatHelper: bump compat for Enzyme_jll to 0.0.151, (keep existing compat) * Add symv and bump jll --------- Co-authored-by: CompatHelper Julia Co-authored-by: William Moses --- Project.toml | 2 +- src/compiler.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 9623428c29..0a03500ac9 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.3" -Enzyme_jll = "0.0.150" +Enzyme_jll = "0.0.151" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/src/compiler.jl b/src/compiler.jl index 36d899ac29..08dc5f05c9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7106,7 +7106,7 @@ function GPUCompiler.codegen( disableFallback = String[] ForwardModeDerivatives = - ("nrm2", "dot", "gemm", "gemv", "axpy", "copy", "scal", "symm", "syrk", "potrf") + ("nrm2", "dot", "gemm", "gemv", "axpy", "copy", "scal", "symv", "symm", "syrk", "potrf") ReverseModeDerivatives = ( "nrm2", "dot", @@ -7115,6 +7115,7 @@ function GPUCompiler.codegen( "axpy", "copy", "scal", + "symv", "symm", "trmv", "syrk", From 7c4a31aa28752138b4ded2c46833c2ad2c9ed63e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 20:18:47 -0500 Subject: [PATCH 321/495] Fix randn (#1906) * Fix randn * Update 
internal_rules.jl --- src/internal_rules.jl | 71 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index b6d081d57d..438fcc8ecb 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -74,7 +74,7 @@ function EnzymeRules.inactive( ) return nothing end -function EnzymeRules.inactive(::typeof(Random.randn!), args...) +function EnzymeRules.inactive(::typeof(Random.randn!), ::Random.AbstractRNG, ::AbstractArray) return nothing end function EnzymeRules.inactive(::typeof(Random.default_rng), args...) @@ -1396,4 +1396,71 @@ function EnzymeRules.reverse( smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, ) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} return (nothing, nothing, nothing) -end \ No newline at end of file +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + Ty::Const{typeof(Random.randn!)}, + RT::Type, + rng::Annotation{<:Random.AbstractRNG}, + dst::Annotation{<:AbstractArray}) + + Ty.val(rng.val, dst.val) + + if !(dst isa Const) + if EnzymeRules.width(config) == 1 + make_zero!(dst.dval) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + make_zero!(dst.dval[i]) + nothing + end + end + end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + dst + elseif EnzymeRules.needs_shadow(config) + dst.dval + elseif EnzymeRules.needs_primal(config) + dst.val + else + nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.randn!)}, + RT::Type, + rng::Annotation{<:Random.AbstractRNG}, + dst::Annotation{<:AbstractArray} +) + Ty.val(rng.val, dst.val) + if RT <: Duplicated || RT <: DuplicatedNoNeed + make_zero!(dst.dval) + dst.dval + elseif RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + make_zero!(dst.dval[i]) + nothing + 
end + end + return EnzymeRules.AugmentedReturn( + EnzymeRules.needs_primal(config) ? dst.val : nothing, + EnzymeRules.needs_shadow(config) ? dst.dval : nothing, + nothing, + ) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.randn!)}, + RT::Type, + tape, + rng::Annotation{<:Random.AbstractRNG}, + dst::Annotation{<:AbstractArray}) + return (nothing, nothing) +end From a5c6fee3e6470a17aea20342b9166feee1a0ffa3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 20:53:43 -0500 Subject: [PATCH 322/495] Fix deferred any active return (#1909) * Fix deferred any active return * fix * fix --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 24 +-- src/Enzyme.jl | 245 +++++++++++++++++++------------ src/internal_rules.jl | 4 + test/abi.jl | 14 ++ test/runtests.jl | 4 +- 7 files changed, 188 insertions(+), 107 deletions(-) diff --git a/Project.toml b/Project.toml index 0a03500ac9..3b8ddd2060 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.3" +EnzymeCore = "0.8.4" Enzyme_jll = "0.0.151" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 3a871b930c..2e45d2c2f6 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.3" +version = "0.8.4" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f76b302302..a536a664aa 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -289,21 +289,21 @@ Reverse mode differentiation. 
- `Width`: Batch Size (0 if to be automatically derived) - `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). """ -struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end -const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}() -const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}() -@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}() -@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}() +struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false, false, false}() +const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false, false, false}() +@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline 
ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() -@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, true}() -@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}() +@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, true, ShadowInit}() +@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, Holomorphic, false, ShadowInit}() -@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where 
{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = 
ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() -@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() """ diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 091021cb8b..aa018ea23b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -350,18 +350,13 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Cannot differentiate with a batch size of 0")) end - ModifiedBetween = Val(falses_from_args(Nargs + 1)) + ModifiedBetweenT = falses_from_args(Nargs + 1) + ModifiedBetween = Val(ModifiedBetweenT) tt = Tuple{map(T -> 
eltype(Core.Typeof(T)), args)...} FTy = Core.Typeof(f.val) - opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) - else - Val(codegen_world_age(FTy, tt)) - end - rt = if A isa UnionAll Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) else @@ -370,20 +365,22 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk( - opt_mi, + forward, adjoint = autodiff_thunk( + ReverseModeSplit{ + ReturnPrimal, + #=ReturnShadow=#false, + RuntimeActivity, + width, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#true + }(), FA, Duplicated{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(width), - ModifiedBetween, - Val(ReturnPrimal), - Val(true), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + (tt′).parameters... + ) res = forward(f, args...) tape = res[1] if ReturnPrimal @@ -400,6 +397,12 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Duplicated Returns not yet handled")) end + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(FTy, tt)) + end + if (A <: Active && rt <: Complex) && rt != Union{} if Holomorphic seen = IdDict() @@ -651,7 +654,7 @@ Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred( - ::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + rmode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -660,7 +663,7 @@ code, as well as high-order differentiation. A<:Annotation, ReturnPrimal, Nargs, - ABI, + RABI<:ABI, Holomorphic, ErrIfFuncWritten, RuntimeActivity, @@ -672,27 +675,85 @@ code, as well as high-order differentiation. 
end tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) + FTy = Core.Typeof(f.val) + world = codegen_world_age(FTy, tt) + + A2 = A if A isa UnionAll - rt = Core.Compiler.return_type(f.val, tt) - rt = A{rt} + rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) + A2 = A{rt} else @assert A isa DataType rt = A end - if eltype(rt) == Union{} + if rt == Union{} error("Return type inferred to be Union{}. Giving up.") end - ModifiedBetween = Val(falses_from_args(Nargs + 1)) + ModifiedBetweenT = falses_from_args(Nargs + 1) + ModifiedBetween = Val(ModifiedBetweenT) + + if A <: Active + if (!allocatedinline(rt) || rt isa Union) && rt != Union{} + rs = ReverseModeSplit{ + ReturnPrimal, + #=ReturnShadow=#false, + RuntimeActivity, + width, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#true + }() + TapeType = tape_type(rs, FA, Duplicated{rt}, + (tt′).parameters...) + forward, adjoint = autodiff_deferred_thunk( + rs, + TapeType, + FA, + Duplicated{rt}, + (tt′).parameters... + ) + res = forward(f, args...) + tape = res[1] + if ReturnPrimal + return (adjoint(f, args..., tape)[1], res[2]) + else + return adjoint(f, args..., tape) + end + end + elseif A <: Duplicated || + A <: DuplicatedNoNeed || + A <: BatchDuplicated || + A <: BatchDuplicatedNoNeed || + A <: BatchDuplicatedFunc + throw(ErrorException("Duplicated Returns not yet handled")) + end + + if (A <: Active && rt <: Complex) && rt != Union{} + if Holomorphic + throw( + ErrorException( + "Reverse-mode Active Holomorphic is not yet implemented in deferred codegen", + ), + ) + end + + throw( + ErrorException( + "Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. 
See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.", + ), + ) + end adjoint_ptr = Compiler.deferred_codegen( Val(world), FA, Val(tt′), - Val(rt), + Val(A), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, @@ -704,9 +765,9 @@ code, as well as high-order differentiation. ) #=ShadowInit=# thunk = - Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,rt,tt′,width,ReturnPrimal}(adjoint_ptr) - if rt <: Active - args = (args..., Compiler.default_adjoint(eltype(rt))) + Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,A2,tt′,width,ReturnPrimal}(adjoint_ptr) + if A <: Active + args = (args..., Compiler.default_adjoint(rt)) elseif A <: Duplicated || A <: DuplicatedNoNeed || A <: BatchDuplicated || @@ -723,7 +784,7 @@ Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compil code, as well as high-order differentiation. """ @inline function autodiff_deferred( - ::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -732,7 +793,7 @@ code, as well as high-order differentiation. 
FA<:Annotation, A<:Annotation, Nargs, - ABI, + RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, } @@ -857,7 +918,9 @@ result, ∂v, ∂A Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit }, ::Type{FA}, ::Type{A}, @@ -872,6 +935,7 @@ result, ∂v, ∂A RABI<:ABI, Nargs, ErrIfFuncWritten, + ShadowInit, RuntimeActivity, } width = if Width == 0 @@ -892,9 +956,6 @@ result, ∂v, ∂A tt = Tuple{map(eltype, args)...} - if !(A <: Const) - @assert ReturnShadow - end tt′ = Tuple{args...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) @@ -910,7 +971,7 @@ result, ∂v, ∂A Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -1055,7 +1116,9 @@ end Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit, }, ::Type{FA}, ::Type{A}, @@ -1071,6 +1134,7 @@ end Nargs, ErrIfFuncWritten, RuntimeActivity, + ShadowInit, } width = if Width == 0 w = same_or_one(1, args...) @@ -1088,7 +1152,6 @@ end ModifiedBetween = Val(ModifiedBetweenT) end - @assert ReturnShadow TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} @@ -1106,7 +1169,7 @@ end Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -1134,6 +1197,9 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, + #=ErrIfFuncWritten=#false, + #=ShadowInit=#false, }, ::Type{FA}, ::Type{A}, @@ -1215,7 +1281,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType end """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) + autodiff_deferred_thunk(::ReverseModeSplit, TapeType::Type, ftype::Type{<:Annotation}, Activity::Type{<:Annotation}, argtypes::Type{<:Annotation}...) 
Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -1266,7 +1332,9 @@ result, ∂v, ∂A Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit, }, tt::Type{TapeType}, fa::Type{FA}, @@ -1284,6 +1352,7 @@ result, ∂v, ∂A Nargs, ErrIfFuncWritten, RuntimeActivity, + ShadowInit } @assert RABI == FFIABI width = if Width == 0 @@ -1302,7 +1371,6 @@ result, ∂v, ∂A ModifiedBetween = Val(ModifiedBetweenT) end - @assert ReturnShadow TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} @@ -1317,7 +1385,7 @@ result, ∂v, ∂A Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -2054,7 +2122,6 @@ this function will retun an AbstractArray of shape `size(output)` of values of t jac end else - @assert !Holomorphic n_out_val = if length(Compiler.element(n_outs)) == 0 0 else @@ -2074,32 +2141,27 @@ this function will retun an AbstractArray of shape `size(output)` of values of t Core.Compiler.return_type(f, tt) end - ModifiedBetween = Val((false, false)) + ModifiedBetweenT = (false, false) FRT = Core.Typeof(f) FA = Const{FRT} - opt_mi = if RABI <: NonGenABI - Compiler.fspec(FRT, tt′) - else - Val(codegen_world_age(FRT, tt)) - end - if chunk == Val(1) || chunk == nothing - tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} - primal, adjoint = Enzyme.Compiler.thunk( - opt_mi, + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + #=width=#1, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, DuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(1), - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + MD ? 
MixedDuplicated{XT} : Duplicated{XT} + ) tmp = ntuple(Val(n_out_val)) do i Base.@_inline_meta z = make_zero(x) @@ -2115,23 +2177,22 @@ this function will retun an AbstractArray of shape `size(output)` of values of t rows, outshape else chunksize = Compiler.element(chunk) - tt′ = - MD ? Tuple{BatchMixedDuplicated{XT,chunksize}} : - Tuple{BatchDuplicated{XT,chunksize}} - primal, adjoint = Enzyme.Compiler.thunk( - opt_mi, + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + chunksize, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, - BatchDuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - chunk, - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + BatchDuplicatedNoNeed{rt, chunksize}, + MD ? BatchMixedDuplicated{XT, chunksize} : BatchDuplicated{XT, chunksize} + ) num = ((n_out_val + chunksize - 1) ÷ chunksize) @@ -2141,20 +2202,22 @@ this function will retun an AbstractArray of shape `size(output)` of values of t else last_size = n_out_val - (num - 1) * chunksize tt′ = Tuple{BatchDuplicated{Core.Typeof(x),last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk( - opt_mi, + primal2, adjoint2 = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + last_size, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, - BatchDuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(last_size), - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + BatchDuplicatedNoNeed{rt, last_size}, + MD ? 
BatchMixedDuplicated{XT, last_size} : BatchDuplicated{XT, last_size} + ) end tmp = ntuple(num) do i diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 438fcc8ecb..53ca1f9283 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -351,6 +351,8 @@ function EnzymeRules.augmented_primal( EnzymeRules.overwritten(config)[2:end], InlineABI, false, + false, + false }() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) @@ -405,6 +407,8 @@ function EnzymeRules.reverse( EnzymeRules.overwritten(config)[2:end], InlineABI, false, + false, + false }() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) diff --git a/test/abi.jl b/test/abi.jl index cbd467c155..5acb30e04f 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -300,6 +300,20 @@ using Test # returns: sret, const/ghost, !deserve_retbox end +unstable_load(x) = Base.inferencebarrier(x)[1] + +@testset "Any Return" begin + x = [2.7] + dx = [0.0] + Enzyme.autodiff(Reverse, Const(unstable_load), Active, Duplicated(x, dx)) + @test dx ≈ [1.0] + + x = [2.7] + dx = [0.0] + Enzyme.autodiff_deferred(Reverse, Const(unstable_load), Active, Duplicated(x, dx)) + @test dx ≈ [1.0] +end + @testset "Mutable Struct ABI" begin mutable struct MStruct val::Float32 diff --git a/test/runtests.jl b/test/runtests.jl index bd1c7dd90d..8b496b071f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -513,8 +513,8 @@ end mul3(z) = Base.inferencebarrier(2 * z) - @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active, Active(z)) - @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) + @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active, Active(z)) + @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) vals = Complex{Float64}[3.4 + 2.7im] dvals = Complex{Float64}[0.0] From 327558b3c31ac8714f5468409d8752b0eec1df0e Mon Sep 17 00:00:00 2001 
From: William Moses Date: Fri, 27 Sep 2024 20:53:59 -0500 Subject: [PATCH 323/495] Handle type unstable getglobal (#1910) --- src/rules/jitrules.jl | 105 +++++++++++++++++++++++------------------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 75bc415654..28ecb7afea 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -456,26 +456,32 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) else annotation0 end - world = codegen_world_age(FT, tt) - opt_mi = Val(world) - forward, adjoint = thunk( - opt_mi, - dupClosure0 ? $dupty : Const{FT}, - annotationA, - Tuple{$(Types...)}, - Val(API.DEM_ReverseModePrimal), - width, - ModifiedBetween, - Val(true), - Val(false), - FFIABI, - Val(false), - runtimeActivity, - ) #=erriffuncwritten=# + internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) + gv = Core.getglobal(args[1].val, args[2].val) + @assert sizeof(gv) == 0 + (nothing, f, nothing, Const) + else + world = codegen_world_age(FT, tt) - internal_tape, origRet, initShadow = forward(dupClosure0 ? $dup : Const(f), args...) - annotation = annotationA + opt_mi = Val(world) + forward, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotationA, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# + + (forward(dupClosure0 ? $dup : Const(f), args...)..., annotationA) + end resT = typeof(origRet) if annotation <: Const @@ -649,39 +655,42 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation0 end - world = codegen_world_age(FT, tt) + if f isa typeof(Core.getglobal) + else + world = codegen_world_age(FT, tt) - opt_mi = Val(world) - _, adjoint = thunk( - opt_mi, - dupClosure0 ? 
$dupty : Const{FT}, - annotation, - Tuple{$(Types...)}, - Val(API.DEM_ReverseModePrimal), - width, - ModifiedBetween, - Val(true), - Val(false), - FFIABI, - Val(false), - runtimeActivity, - ) #=erriffuncwritten=# + opt_mi = Val(world) + _, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotation, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# - tup = - if annotation0 <: Active || - annotation0 <: MixedDuplicated || - annotation0 <: BatchMixedDuplicated - adjoint( - dupClosure0 ? $dup : Const(f), - args..., - $shadowret, - tape.internal_tape, - )[1] - else - adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] - end + tup = + if annotation0 <: Active || + annotation0 <: MixedDuplicated || + annotation0 <: BatchMixedDuplicated + adjoint( + dupClosure0 ? $dup : Const(f), + args..., + $shadowret, + tape.internal_tape, + )[1] + else + adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] + end - $(outs...) + $(outs...) 
+ end return nothing end From 259f16132157709f9a94020854439975db92b953 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 22:08:28 -0500 Subject: [PATCH 324/495] Optimize active only rev grad (#1911) * Optimize active only rev grad * Update Project.toml * add makezero s/marray --- ext/EnzymeStaticArraysExt.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index bcaa3ec6cb..b751c336a2 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -23,4 +23,11 @@ end end end +@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} + return Base.zero(x) +end +@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} + return Base.zero(x) +end + end From 4ff5e44de2c3403818dd6f5c2b10d66bbc6e359d Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 23:00:42 -0500 Subject: [PATCH 325/495] Fix getglobal value (#1912) --- src/rules/jitrules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 28ecb7afea..f169ab2c4b 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -458,9 +458,9 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) - gv = Core.getglobal(args[1].val, args[2].val) + gv = Core.getglobal(map(x->x.val, args)...) 
@assert sizeof(gv) == 0 - (nothing, f, nothing, Const) + (nothing, gv, nothing, Const) else world = codegen_world_age(FT, tt) From edd00954a7e274064bcd9097485bcdf19f3624cc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 28 Sep 2024 18:26:41 +0200 Subject: [PATCH 326/495] Improve documentation of modes (#1895) * Improve documentation of modes * Alignment * Add comment on setter functions * List helper functions * More cases for set_runtime_activity * Merge remote-tracking branch 'upstream/main' into gd/modes_doc * Cleaner diff * Smaller diff --- lib/EnzymeCore/src/EnzymeCore.jl | 251 +++++++++++++++++++++++++++---- 1 file changed, 220 insertions(+), 31 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index a536a664aa..394cd00a5f 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -209,50 +209,123 @@ end abstract type ABI Abstract type for what ABI will be used. + +# Subtypes + +- [`FFIABI`](@ref) (the default) +- [`InlineABI`](@ref) +- [`NonGenABI`](@ref) """ abstract type ABI end """ struct FFIABI <: ABI -Foreign function call ABI. JIT the differentiated function, then inttoptr call the address. +Foreign function call [`ABI`](@ref). JIT the differentiated function, then inttoptr call the address. """ struct FFIABI <: ABI end + """ struct InlineABI <: ABI -Inlining function call ABI. +Inlining function call [`ABI`](@ref). """ struct InlineABI <: ABI end + """ struct NonGenABI <: ABI -Non-generated function ABI. +Non-generated function [`ABI`](@ref). """ struct NonGenABI <: ABI end + const DefaultABI = FFIABI """ - abstract type Mode + abstract type Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + +Abstract type for which differentiation mode will be used. -Abstract type for what differentiation mode will be used. 
+# Subtypes + +- [`ForwardMode`](@ref) +- [`ReverseMode`](@ref) +- [`ReverseModeSplit`](@ref) + +# Type parameters + +- `ABI`: what runtime [`ABI`](@ref) to use +- `ErrIfFuncWritten`: whether to error when the function differentiated is a closure and written to. +- `RuntimeActivity`: whether to enable runtime activity (default off) + +!!! warning + The type parameters of `Mode` are not part of the public API and can change without notice. + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) """ abstract type Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end """ - struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} + struct ReverseMode{ + ReturnPrimal, + RuntimeActivity, + ABI, + Holomorphic, + ErrIfFuncWritten + } <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + +Subtype of [`Mode`](@ref) for reverse mode differentiation. -Reverse mode differentiation. -- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. -- `RuntimeActivity`: Should Enzyme enable runtime activity (default off) -- `ABI`: What runtime ABI to use -- `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz -- `ErrIfFuncWritten`: Should Enzyme err if the function differentiated is a closure and written to. +# Type parameters + +- `ReturnPrimal`: whether to return the primal return value from the augmented-forward pass. +- `Holomorphic`: Whether the complex result function is holomorphic and we should compute `d/dz` +- other parameters: see [`Mode`](@ref) + +!!! warning + The type parameters of `ReverseMode` are not part of the public API and can change without notice. 
+ Please use one of the following concrete instantiations instead: + - [`Reverse`](@ref) + - [`ReverseWithPrimal`](@ref) + - [`ReverseHolomorphic`](@ref) + - [`ReverseHolomorphicWithPrimal`](@ref) + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) """ struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end + +""" + const Reverse + +Default instance of [`ReverseMode`](@ref) that doesn't return the primal +""" const Reverse = ReverseMode{false,false,DefaultABI, false, false}() + +""" + const ReverseWithPrimal + +Default instance of [`ReverseMode`](@ref) that also returns the primal. +""" const ReverseWithPrimal = ReverseMode{true,false,DefaultABI, false, false}() + +""" + const ReverseHolomorphic + +Holomorphic instance of [`ReverseMode`](@ref) that doesn't return the primal +""" const ReverseHolomorphic = ReverseMode{false,false,DefaultABI, true, false}() + +""" + const ReverseHolomorphicWithPrimal + +Holomorphic instance of [`ReverseMode`](@ref) that also returns the primal +""" const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, false}() @inline set_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,true}() @@ -265,34 +338,80 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa @inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}() """ - 
WithPrimal(::Enzyme.Mode) + WithPrimal(::Mode) -Modifies the mode to include the primal value. +Return a new mode which includes the primal value. """ @inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() """ - NoPrimal(::Enzyme.Mode) + NoPrimal(::Mode) -Modifies the mode to exclude the primal value. +Return a new mode which excludes the primal value. """ @inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() - """ - struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + struct ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + Width, + RuntimeActivity, + ModifiedBetween, + ABI, + ErrFuncIfWritten + } <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + WithPrimal(::Enzyme.Mode) + +Subtype of [`Mode`](@ref) for split reverse mode differentiation, to use in [`autodiff_thunk`](@ref) and variants. + +# Type parameters -Reverse mode differentiation. -- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. -- `ReturnShadow`: Should Enzyme return the shadow return value from the augmented-forward. -- `RuntimeActivity`: Should Enzyme differentiate with runtime activity on (default off). -- `Width`: Batch Size (0 if to be automatically derived) -- `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). +- `ReturnShadow`: whether to return the shadow return value from the augmented-forward. 
+- `Width`: batch size (pick `0` to derive it automatically) +- `ModifiedBetween`: `Tuple` of each argument's "modified between" state (pick `true` to derive it automatically). +- other parameters: see [`ReverseMode`](@ref) + +!!! warning + The type parameters of `ReverseModeSplit` are not part of the public API and can change without notice. + Please use one of the following concrete instantiations instead: + - [`ReverseSplitNoPrimal`](@ref) + - [`ReverseSplitWithPrimal`](@ref) + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) + - [`ReverseSplitModified`](@ref), [`ReverseSplitWidth`](@ref) """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end + +""" + const ReverseSplitNoPrimal + +Default instance of [`ReverseModeSplit`](@ref) that doesn't return the primal +""" const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false, false, false}() + +""" + const ReverseSplitWithPrimal + +Default instance of [`ReverseModeSplit`](@ref) that also returns the primal +""" const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false, false, false}() + +""" + ReverseSplitModified(::ReverseModeSplit, ::Val{MB}) + +Return a new instance of [`ReverseModeSplit`](@ref) mode where `ModifiedBetween` is set to `MB`. 
+""" @inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() + +""" + ReverseSplitWidth(::ReverseModeSplit, ::Val{W}) + +Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to `W`. +""" @inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() @inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, true, ShadowInit}() @@ -307,13 +426,46 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau """ - struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} + struct ForwardMode{ + ReturnPrimal, + ABI, + ErrIfFuncWritten, + RuntimeActivity + } <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + +Subtype of [`Mode`](@ref) for forward mode differentiation. -Forward mode differentiation +# Type parameters + +- `ReturnPrimal`: whether to return the primal return value from the augmented-forward. +- other parameters: see [`Mode`](@ref) + +!!! 
warning + The type parameters of `ForwardMode` are not part of the public API and can change without notice. + Please use one of the following concrete instantiations instead: + - [`Forward`](@ref) + - [`ForwardWithPrimal`](@ref) + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) """ struct ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end + +""" + const Forward + +Default instance of [`ForwardMode`](@ref) that doesn't return the primal +""" const Forward = ForwardMode{false, DefaultABI, false, false}() + +""" + const ForwardWithPrimal + +Default instance of [`ForwardMode`](@ref) that also returns the primal +""" const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline set_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,true,RuntimeActivity}() @@ -337,22 +489,22 @@ function autodiff_deferred_thunk end """ make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T - Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies - what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. +Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies +what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. """ function make_zero end """ make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing - Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. 
+Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. """ function make_zero! end """ make_zero(prev::T) - Helper function to recursively make zero. +Helper function to recursively make zero. """ @inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive} make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive)) @@ -383,10 +535,47 @@ if !isdefined(Base, :get_extension) end """ - within_autodiff() + within_autodiff() Returns true if within autodiff, otherwise false. """ function within_autodiff end +""" + set_err_if_func_written(::Mode) + +Return a new mode which throws an error for any attempt to write into an unannotated function object. +""" +function set_err_if_func_written end + +""" + clear_err_if_func_written(::Mode) + +Return a new mode which doesn't throw an error for attempts to write into an unannotated function object. +""" +function clear_err_if_func_written end + +""" + set_runtime_activity(::Mode) + set_runtime_activity(::Mode, activitiy::Bool) + set_runtime_activity(::Mode, config::Union{FwdConfig,RevConfig}) + +Return a new mode where runtime activity analysis is activated / set to the desired value. +""" +function set_runtime_activity end + +""" + clear_runtime_activity(::Mode) + +Return a new mode where runtime activity analysis is deactivated. +""" +function clear_runtime_activity end + +""" + set_abi(::Mode, ::Type{ABI}) + +Return a new mode with its [`ABI`](@ref) set to the chosen type. +""" +function set_abi end + end # module EnzymeCore From 467b4f7a7cd368ba2189b59464d5904cb3259394 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Sat, 28 Sep 2024 12:27:23 -0400 Subject: [PATCH 327/495] fix some cases of gradient/jacobian with StaticArrays (#1875) * fix some cases of gradient/jacobian with StaticArrays * add tests * Update EnzymeStaticArraysExt.jl * jacobian is exported, wtf? 
--------- Co-authored-by: William Moses --- ext/EnzymeStaticArraysExt.jl | 9 ++++++- test/runtests.jl | 47 +++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index b751c336a2..af31d405d7 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,7 +3,14 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme -@inline Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) = reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) +@inline function Base.convert(::Type{SArray}, tpa::Enzyme.TupleArray{T,S,L,N}) where {T,S,L,N} + SArray{Tuple{S...},T,N,L}(tpa.data) +end +@inline Base.convert(::Type{StaticArray}, tpa::Enzyme.TupleArray) = convert(SArray, tpa) + +@inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) + reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...)) +end @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i diff --git a/test/runtests.jl b/test/runtests.jl index 8b496b071f..eb9dfccb6c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2854,6 +2854,51 @@ end @test dx[1] ≈ 0 @test dx[2] ≈ 30 @test dx[3] ≈ 0 + + f0 = x -> sum(2*x) + f1 = x -> @SVector Float64[x[2], 2*x[2]] + f2 = x -> @SMatrix Float64[x[2] x[1]; 2*x[2] 2*x[1]] + + x = @SVector Float64[1, 2] + + dx = gradient(Forward, f0, x)[1] + @test dx isa Enzyme.TupleArray + @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test gradient(Forward, f1, x)[1] isa SMatrix + @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + dx = gradient(Forward, f0, x)[1] + @test dx isa Enzyme.TupleArray + @test 
convert(SArray, dx) == fill(2.0, (2,2)) + @test gradient(Forward, f1, x)[1] isa SArray + @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) + + x = @SVector Float64[1, 2] + + dx = gradient(Reverse, f0, x)[1] + @test dx isa SVector + @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test_broken gradient(Reverse, f1, x)[1] isa SMatrix + @test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0] + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + @test_broken gradient(Reverse, f1, x)[1] isa SArray + @test_broken gradient(Reverse, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) end function unstable_fun(A0) @@ -4101,4 +4146,4 @@ include("ext/logexpfunctions.jl") @testset "BFloat16s ext" begin include("ext/bfloat16s.jl") -end \ No newline at end of file +end From b999e5a00301a11dddfa1f9da416c6736eaf3bd9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 13:14:28 -0500 Subject: [PATCH 328/495] Use actual_size instead of sizeof (#1915) * Use actual_size instead of sizeof * Better error str --- src/rules/llvmrules.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index a41912cf82..3c4b95d8ee 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1519,7 +1519,7 @@ end found, arty, byref = abs_typeof(origops[1]) anti = shadowin elSize = if found - LLVM.ConstantInt(Csize_t(sizeof(eltype(arty)))) + 
LLVM.ConstantInt(Csize_t(actual_size(eltype(arty)))) else elSize = LLVM.zext!( B, @@ -1534,7 +1534,12 @@ end length = LLVM.mul!(B, len, elSize) if !found && !(eltype(arty) <: Base.IEEEFloat) - GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $((found, arty)) in $(string(origops[1]))" + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $((found, arty)) in $(string(origops[1])) $btstr" end toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[length]) From 55582f8ba60f41411d162b3d3d73155d0545199c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 14:00:41 -0500 Subject: [PATCH 329/495] Even more indexed typeinfo (#1916) --- src/absint.jl | 61 ++++++++++++++++++++++++++----------------------- src/compiler.jl | 8 ++++--- src/typetree.jl | 25 +++++++++++++++++--- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 041b3bd1cc..77ce2b6a7e 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -204,6 +204,34 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end +function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int, Bool} + offset = 0 + error = false + while true + if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) + larg = operands(larg)[1] + continue + end + if isa(larg, LLVM.GetElementPtrInst) && + all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) + b = LLVM.IRBuilder() + position!(b, larg) + offty = LLVM.IntType(8 * sizeof(Int)) + offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) + @assert isa(offset2, LLVM.ConstantInt) + offset += convert(Int, offset2) + larg = operands(larg)[1] + continue + end + if isa(larg, LLVM.Argument) + break + end + error = true + break + end + return larg, offset, error 
+end + function abs_typeof( arg::LLVM.Value, partial::Bool = false, @@ -354,32 +382,7 @@ function abs_typeof( end if isa(arg, LLVM.LoadInst) - larg = operands(arg)[1] - offset = nothing - error = false - while true - if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) - larg = operands(larg)[1] - continue - end - if offset === nothing && - isa(larg, LLVM.GetElementPtrInst) && - all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) - b = LLVM.IRBuilder() - position!(b, larg) - offty = LLVM.IntType(8 * sizeof(Int)) - offset = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) - @assert isa(offset, LLVM.ConstantInt) - offset = convert(Int, offset) - larg = operands(larg)[1] - continue - end - if isa(larg, LLVM.Argument) - break - end - error = true - break - end + larg, offset, error = get_base_and_offset(operands(arg)[1]) if !error legal, typ, byref = abs_typeof(larg) @@ -387,7 +390,7 @@ function abs_typeof( @static if VERSION < v"1.11-" if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) - if offset === nothing || offset == 0 + if offset == 0 return (true, Ptr{T}, GPUCompiler.BITS_VALUE) else return (true, Int, GPUCompiler.BITS_VALUE) @@ -400,14 +403,14 @@ function abs_typeof( byref = GPUCompiler.BITS_VALUE legal = true - while (offset !== nothing && offset != 0) && legal + while offset != 0 && legal @assert Base.isconcretetype(typ) seen = false lasti = 1 for i = 1:fieldcount(typ) fo = fieldoffset(typ, i) if fieldoffset(typ, i) == offset - offset = nothing + offset = 0 typ = fieldtype(typ, i) if !Base.allocatedinline(typ) if byref != GPUCompiler.BITS_VALUE diff --git a/src/compiler.jl b/src/compiler.jl index 08dc5f05c9..e890bd998d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7988,7 +7988,8 @@ function GPUCompiler.codegen( if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id - legal, jTy, byref = abs_typeof(operands(inst)[1]) + base, offset, _ = 
get_base_and_offset(operands(inst)[1]) + legal, jTy, byref = abs_typeof(base) sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id @@ -8007,8 +8008,9 @@ function GPUCompiler.codegen( any(T2 isa Core.TypeofVararg for T2 in jTy.parameters) ) ) - if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz) - md = to_fullmd(jTy) + if offset < sizeof(jTy) && isa(sz, LLVM.ConstantInt) && sizeof(jTy) - offset >= convert(Int, sz) + lim = convert(Int, sz) + md = to_fullmd(jTy, offset, lim) @assert byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF metadata(inst)["enzyme_truetype"] = md diff --git a/src/typetree.jl b/src/typetree.jl index 8ddce070b2..c96d41fb2b 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -137,9 +137,28 @@ function get_offsets(@nospecialize(T::Type)) return results end -function to_fullmd(@nospecialize(T::Type)) +function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) mds = LLVM.Metadata[] - for (sT, sO) in get_offsets(T) + offs = get_offsets(T) + + minoff = -1 + for (sT, sO) in offs + if sO >= offset + if sO == offset + minOff = sO + end + else + minoff = max(minoff, sO) + end + end + + for (sT, sO) in offs + if sO != minoff && (sO < offset) + continue + end + if sO >= lim + continue + end if sT == API.DT_Pointer push!(mds, LLVM.MDString("Pointer")) elseif sT == API.DT_Integer @@ -155,7 +174,7 @@ function to_fullmd(@nospecialize(T::Type)) else @assert false end - push!(mds, LLVM.Metadata(LLVM.ConstantInt(sO))) + push!(mds, LLVM.Metadata(LLVM.ConstantInt(min(0, sO - offset)))) end return LLVM.MDNode(mds) end From b739cbfb241b6305ae4ca73683d590d23f141ed8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 15:45:48 -0500 Subject: [PATCH 330/495] Correct offset to use max (#1918) * Correct offset to use max * fix --- src/typetree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/typetree.jl b/src/typetree.jl index c96d41fb2b..89fe522670 100644 --- 
a/src/typetree.jl +++ b/src/typetree.jl @@ -174,7 +174,7 @@ function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) else @assert false end - push!(mds, LLVM.Metadata(LLVM.ConstantInt(min(0, sO - offset)))) + push!(mds, LLVM.Metadata(LLVM.ConstantInt(max(0, sO - offset)))) end return LLVM.MDNode(mds) end From 287b847b2c7c7382da0503c7a6bd810f933def16 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 15:46:04 -0500 Subject: [PATCH 331/495] Fix active reg inner of literal type (#1917) --- src/compiler.jl | 2 +- test/runtests.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index e890bd998d..560ef0fcb3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -663,7 +663,7 @@ end return AnyState end - if isghostty(T) || Core.Compiler.isconstType(T) + if isghostty(T) || Core.Compiler.isconstType(T) || T <: Type return AnyState end diff --git a/test/runtests.jl b/test/runtests.jl index eb9dfccb6c..f3b5be421f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,7 @@ mutable struct MInts{A, B} end @testset "Internal tests" begin + @assert Enzyme.Compiler.active_reg_inner(Type{Array}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Integer}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Float64}, (), nothing) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:Any}, (), nothing) == Enzyme.Compiler.DupState From d91151bae770bceb1aa639e211f54e16c913c642 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 16:09:11 -0500 Subject: [PATCH 332/495] Fix limit being relative (#1919) * Fix limit being relative * fix --- src/typetree.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index 89fe522670..c886c683ce 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -145,7 +145,7 @@ function 
to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) for (sT, sO) in offs if sO >= offset if sO == offset - minOff = sO + minoff = sO end else minoff = max(minoff, sO) @@ -156,7 +156,7 @@ function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) if sO != minoff && (sO < offset) continue end - if sO >= lim + if sO >= lim + offset continue end if sT == API.DT_Pointer From 4ab422baa2f76064219636273135c8854f94e48b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 20:39:03 -0500 Subject: [PATCH 333/495] Fix memory of float (#1920) --- src/compiler.jl | 8 ++++++++ test/runtests.jl | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 560ef0fcb3..1b11da719e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -523,6 +523,10 @@ end @inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V @inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V @inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T +@static if VERSION < v"1.11-" +else +@inline ptreltype(::Type{Memory{T}}) where T = T +end @inline is_arrayorvararg_ty(::Type) = false @inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true @@ -535,6 +539,10 @@ end @inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true @inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true +@static if VERSION < v"1.11-" +else +@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true +end @inline function datatype_fieldcount(t::Type{T}) where {T} return Base.datatype_fieldcount(t) diff --git a/test/runtests.jl b/test/runtests.jl index f3b5be421f..b4aab36752 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,10 @@ mutable struct MInts{A, B} end @testset "Internal tests" begin + @static if VERSION < v"1.11-" + else + @assert Enzyme.Compiler.active_reg_inner(Memory{Float64}, (), nothing) == Enzyme.Compiler.DupState + 
end @assert Enzyme.Compiler.active_reg_inner(Type{Array}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Integer}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Float64}, (), nothing) == Enzyme.Compiler.DupState From b968cfe5a54bf5b9285217fd476605eaae54ed4a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 23:52:01 -0500 Subject: [PATCH 334/495] Attempt to fix apple (#1834) --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index b4aab36752..e720ba46c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -364,6 +364,9 @@ make3() = (1.0, 2.0, 3.0) @test autodiff(Forward, tanh, Duplicated(1.0f0, 1.0f0))[1] ≈ Float32(0.41997434161402606939) for T in (Float64, Float32, Float16) + if T == Float16 && Sys.isapple() + continue + end res = autodiff(Reverse, tanh, Active, Active(T(1)))[1][1] @test res isa T cmp = if T == Float64 From a16f41a68fc8a19e9d70a2b57d349bfd53252063 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 01:30:54 -0500 Subject: [PATCH 335/495] Bump jll (#1921) --- Project.toml | 2 +- test/runtests.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3b8ddd2060..21d0cab215 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.151" +Enzyme_jll = "0.0.152" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/test/runtests.jl b/test/runtests.jl index e720ba46c4..ba4462bc23 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3557,6 +3557,36 @@ end @test din[2, 1] ≈ 1.0 end +@testset "View Vars" begin + + x = [Float32(0.25)] + dx = [Float32(0.0)] + rng = Base.UnitRange{Int64}(1, 0) + + f = 
Const(Base.SubArray{T, N, P, I, L} where L where I where P where N where T) + a1 = Const(Base.IndexLinear()) + a2 = Duplicated(x, dx) + a3 = Const((rng,)) + a4 = Const((true,)) + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, + typeof(f), + Duplicated, + typeof(a1), + typeof(a2), + typeof(a3), + typeof(a4) + ) + + res = fwd(f,a1,a2,a3,a4) + @test res[2].indices == (rng,) + @test res[3].indices == (rng,) + @test res[2].offset1 == 0 + @test res[3].offset1 == 0 + @test res[2].stride1 == 1 + @test res[3].stride1 == 1 +end + @testset "Uncached batch sizes" begin genericsin(x) = Base.invokelatest(sin, x) res = Enzyme.autodiff(Forward, genericsin, BatchDuplicated(2.0, NTuple{10,Float64}((Float64(i) for i in 1:10))))[1] From f14ad34dfde2323d50b0339bec2e88a2bb729aee Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 13:21:51 -0500 Subject: [PATCH 336/495] Bump jll (#1922) --- Project.toml | 4 ++-- test/runtests.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 21d0cab215..87ab945958 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.4" +version = "0.13.5" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.152" +Enzyme_jll = "0.0.153" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/test/runtests.jl b/test/runtests.jl index ba4462bc23..902b9e4f65 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3585,6 +3585,33 @@ end @test res[3].offset1 == 0 @test res[2].stride1 == 1 @test res[3].stride1 == 1 + + x = [Float32(0.25)] + dx = [Float32(0.0)] + rng = Base.UnitRange{Int64}(1, 0) + + f = Const(Base.SubArray{T, N, P, I, L} where L where 
I where P where N where T) + a1 = Const(Base.IndexLinear()) + a2 = Duplicated(x, dx) + a3 = Const((rng,)) + a4 = Const((true,)) + + fwd, rev = autodiff_thunk(set_runtime_activity(ReverseSplitWithPrimal), + typeof(f), + Duplicated, + typeof(a1), + typeof(a2), + typeof(a3), + typeof(a4) + ) + + res = fwd(f,a1,a2,a3,a4) + @test res[2].indices == (rng,) + @test res[3].indices == (rng,) + @test res[2].offset1 == 0 + @test res[3].offset1 == 0 + @test res[2].stride1 == 1 + @test res[3].stride1 == 1 end @testset "Uncached batch sizes" begin From c9eae5b83f9053186111573da57f8e7ef3ffc947 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 18:56:50 -0500 Subject: [PATCH 337/495] Fix make_zero on constant fields (#1926) * Fix make_zero on constant fields * type --- src/compiler.jl | 400 +--------------------------------------------- src/make_zero.jl | 404 +++++++++++++++++++++++++++++++++++++++++++++++ test/abi.jl | 13 ++ 3 files changed, 418 insertions(+), 399 deletions(-) create mode 100644 src/make_zero.jl diff --git a/src/compiler.jl b/src/compiler.jl index 1b11da719e..16d54481cf 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1803,405 +1803,7 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) allocate_sret!(B, N) end -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{FT,N}, -)::Array{FT,N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{Complex{FT},N}, -)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end - -@inline function EnzymeCore.make_zero( - ::Type{Array{FT,N}}, - seen::IdDict, - prev::Array{FT,N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - 
return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{Array{Complex{FT},N}}, - seen::IdDict, - prev::Array{Complex{FT},N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{Complex{RT}}, - seen::IdDict, - prev::Complex{RT}, - ::Val{copy_if_inactive} = Val(false), -)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Array} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Tuple} - return ntuple(length(prev)) do i - Base.@_inline_meta - EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) - end -end - -@inline function EnzymeCore.make_zero( - ::Type{NamedTuple{A,RT}}, - seen::IdDict, - prev::NamedTuple{A,RT}, - ::Val{copy_if_inactive} = Val(false), -)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) -end - -@inline function EnzymeCore.make_zero( - ::Type{Core.Box}, - seen::IdDict, - prev::Core.Box, - ::Val{copy_if_inactive} = Val(false), -) where {copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - prev2 = prev.contents - res = Core.Box() - seen[prev] = res - res.contents = Base.Ref( - EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), - ) - return res -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT} - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev - end - if haskey(seen, prev) - return seen[prev] - end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - - if ismutable(prev) - y = ccall(:jl_new_struct_uninit, Any, (Any,), RT) - seen[prev] = y - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - T = Core.Typeof(xi) - xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - setfield!(y, i, xi) - end - end - return y - end - - if nf == 0 - return prev - end - - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - flds[i] = xi - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) - seen[prev] = y - return y -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - zero(T) -end - -function make_zero_immutable!( - prev::Complex{T}, - seen::S, -)::Complex{T} where {T<:AbstractFloat,S} - zero(T) -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} - ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - make_zero_immutable!(prev[i], seen) - end -end - -function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - make_zero_immutable!(prev[a[i]], seen) - end) -end - - -function make_zero_immutable!(prev::T, seen::S)::T where {T,S} - if guaranteed_const_nongen(T, nothing) - return prev - end - @assert !ismutable(prev) - - RT = Core.Typeof(prev) - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState 
#=justActive=# - make_zero_immutable!(xi, seen) - else - EnzymeCore.make_zero!(xi, seen) - xi - end - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - T[] = zero(T) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - T[] = zero(Complex{T}) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{T,N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - fill!(prev, zero(T)) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - fill!(prev, zero(Complex{T})) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, -)::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing - else 
- EnzymeCore.make_zero!(pv, seen) - nothing - end - end - end - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T,ST} - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - - pv = prev[] - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev[] = make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - pv = prev.contents - T = Core.Typeof(pv) - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::T, - seen::S = Base.IdSet{Any}(), -)::Nothing where {T,S} - if guaranteed_const_nongen(T, nothing) - return - end - if in(prev, seen) - return - end - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - - - if nf == 0 - return - end - - push!(seen, prev) - - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - SBT = Core.Typeof(xi) - if guaranteed_const_nongen(SBT, nothing) - continue - end - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - setfield!(prev, i, make_zero_immutable!(xi, seen)) - nothing - else - EnzymeCore.make_zero!(xi, seen) - nothing - end - end - end - return -end +include("make_zero.jl") function emit_error(B::LLVM.IRBuilder, orig, string, errty = EnzymeRuntimeException) curent_bb = position(B) diff --git a/src/make_zero.jl b/src/make_zero.jl new file mode 100644 index 0000000000..4f627581ea --- /dev/null +++ 
b/src/make_zero.jl @@ -0,0 +1,404 @@ + +@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero( + x::Array{FT,N}, +)::Array{FT,N} where {FT<:AbstractFloat,N} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero( + x::Array{Complex{FT},N}, +)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} + return Base.zero(x) +end + +@inline function EnzymeCore.make_zero( + ::Type{Array{FT,N}}, + seen::IdDict, + prev::Array{FT,N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end +@inline function EnzymeCore.make_zero( + ::Type{Array{Complex{FT},N}}, + seen::IdDict, + prev::Array{Complex{FT},N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:AbstractFloat} + return RT(0) +end + +@inline function EnzymeCore.make_zero( + ::Type{Complex{RT}}, + seen::IdDict, + prev::Complex{RT}, + ::Val{copy_if_inactive} = Val(false), +)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} + return RT(0) +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Array} + if haskey(seen, prev) + return seen[prev] + end + if guaranteed_const_nongen(RT, nothing) + return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev + end + newa = RT(undef, size(prev)) + seen[prev] = newa + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + innerty = Core.Typeof(pv) + @inbounds newa[I] = + EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) + end + end + return newa +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Tuple} + return ntuple(length(prev)) do i + Base.@_inline_meta + EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) + end +end + +@inline function EnzymeCore.make_zero( + ::Type{NamedTuple{A,RT}}, + seen::IdDict, + prev::NamedTuple{A,RT}, + ::Val{copy_if_inactive} = Val(false), +)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} + return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) +end + +@inline function EnzymeCore.make_zero( + ::Type{Core.Box}, + seen::IdDict, + prev::Core.Box, + ::Val{copy_if_inactive} = Val(false), +) where {copy_if_inactive} + if haskey(seen, prev) + return seen[prev] + end + prev2 = prev.contents + res = Core.Box() + seen[prev] = res + res.contents = Base.Ref( + EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), + ) + return res +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT} + if guaranteed_const_nongen(RT, nothing) + return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev + end + if haskey(seen, prev) + return seen[prev] + end + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + if ismutable(prev) + y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT + seen[prev] = y + for i = 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + T = Core.Typeof(xi) + xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) + if Base.isconst(RT, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) + else + setfield!(y, i, xi) + end + end + end + return y + end + + if nf == 0 + return prev + end + + flds = Vector{Any}(undef, nf) + for i = 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) + flds[i] = xi + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) + seen[prev] = y + return y +end + +function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} + zero(T) +end + +function make_zero_immutable!( + prev::Complex{T}, + seen::S, +)::Complex{T} where {T<:AbstractFloat,S} + zero(T) +end + +function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} + ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable!(prev[i], seen) + end +end + +function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} + NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable!(prev[a[i]], seen) + end) +end + + +function make_zero_immutable!(prev::T, seen::S)::T where {T,S} + if guaranteed_const_nongen(T, nothing) + return prev + end + @assert !ismutable(prev) + + RT = Core.Typeof(prev) + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + flds = Vector{Any}(undef, nf) + for i = 1:nf + if isdefined(prev, i) + xi = 
getfield(prev, i) + ST = Core.Typeof(xi) + flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# + make_zero_immutable!(xi, seen) + else + EnzymeCore.make_zero!(xi, seen) + xi + end + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} + T[] = zero(T) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{Complex{T}}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} + T[] = zero(Complex{T}) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Array{T,N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} + fill!(prev, zero(T)) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} + fill!(prev, zero(Complex{T})) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, +)::Nothing where {T<:AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{Complex{T}}, +)::Nothing where {T<:AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, +)::Nothing where {T<:AbstractFloat,N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} + if guaranteed_const_nongen(T, nothing) + return + end + if in(seen, prev) + return + end + push!(seen, prev) + + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), 
nothing, Val(true)) == ActiveState #=justActive=# + @inbounds prev[I] = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + end + end + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T,ST} + if guaranteed_const_nongen(T, nothing) + return + end + if in(seen, prev) + return + end + push!(seen, prev) + + pv = prev[] + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + prev[] = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} + pv = prev.contents + T = Core.Typeof(pv) + if guaranteed_const_nongen(T, nothing) + return + end + if in(seen, prev) + return + end + push!(seen, prev) + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::T, + seen::S = Base.IdSet{Any}(), +)::Nothing where {T,S} + if guaranteed_const_nongen(T, nothing) + return + end + if in(prev, seen) + return + end + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + + + if nf == 0 + return + end + + push!(seen, prev) + + for i = 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + SBT = Core.Typeof(xi) + if guaranteed_const_nongen(SBT, nothing) + continue + end + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + setfield!(prev, i, make_zero_immutable!(xi, seen)) + nothing + else + EnzymeCore.make_zero!(xi, seen) + nothing + end + end + end + return +end diff --git a/test/abi.jl b/test/abi.jl index 5acb30e04f..7a7917553f 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -480,6 +480,19 @@ 
mulsin(x) = sin(x[1] * x[2]) @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] end +mutable struct ConstVal + x::Float64 + const y::Float64 +end + +@testset "Make Zero" begin + v = ConstVal(2.0, 3.0) + dv = make_zero(v) + @test dv isa ConstVal + @test dv.x ≈ 0.0 + @test dv.y ≈ 0.0 +end + @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) From 288a419d464bd423218e05e6996463d8c98bce42 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 19:45:09 -0500 Subject: [PATCH 338/495] Fix pass manager bug which allows functions to be deleted and replaced (#1924) * Fix pass manager bug which allows functions to be deleted and replaced * fix --- Project.toml | 2 +- src/compiler/optimize.jl | 41 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 87ab945958..b70795fca1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.5" +version = "0.13.6" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index d11daaa0b3..cc143ce4f2 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -2410,8 +2410,35 @@ function optimize!(mod::LLVM.Module, tm) mem_cpy_opt!(pm) always_inliner!(pm) alloc_opt_tm!(pm, tm) + LLVM.run!(pm, mod) + end + + # Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to + # known functions + ModulePassManager() do pm + + add_library_info!(pm, triple(mod)) + add_transform_info!(pm, tm) + + scoped_no_alias_aa!(pm) + type_based_alias_analysis!(pm) + basic_alias_analysis!(pm) + cpu_features_tm!(pm, tm) + LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Extra gvn!(pm) # Extra + LLVM.run!(pm, mod) + end + + ModulePassManager() do pm + 
add_library_info!(pm, triple(mod)) + add_transform_info!(pm, tm) + + scoped_no_alias_aa!(pm) + type_based_alias_analysis!(pm) + basic_alias_analysis!(pm) + cpu_features_tm!(pm, tm) + instruction_combining!(pm) jl_inst_simplify!(pm) cfgsimplification!(pm) @@ -2473,6 +2500,20 @@ function optimize!(mod::LLVM.Module, tm) cfgsimplification!(pm) instruction_combining!(pm) # Extra for Enzyme jl_inst_simplify!(pm) + LLVM.run!(pm, mod) + end + + # Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to + # known functions + ModulePassManager() do pm + add_library_info!(pm, triple(mod)) + add_transform_info!(pm, tm) + + scoped_no_alias_aa!(pm) + type_based_alias_analysis!(pm) + basic_alias_analysis!(pm) + cpu_features_tm!(pm, tm) + LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Exxtra gvn!(pm) # Exxtra LLVM.run!(pm, mod) From dad67bfc3913f4eb66126d7c186a59b0c1f18586 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 22:39:20 -0500 Subject: [PATCH 339/495] Union member type info (#1927) * Union member type info * fix * fix --- src/typetree.jl | 7 ++++++- test/typetree.jl | 35 +++++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index c886c683ce..61d700acb8 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -353,12 +353,17 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) for f = 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) - subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} + if !allocatedinline(subT) + subtree = TypeTree(API.DT_Pointer, offset, ctx) + merge!(tt, subtree) + end # FIXME: Handle union continue end + + subtree = copy(typetree(subT, ctx, dl, seen)) # Allocated inline so adjust first path if allocatedinline(subT) diff --git a/test/typetree.jl b/test/typetree.jl index 1a869d6687..3b47161f62 100644 --- a/test/typetree.jl 
+++ b/test/typetree.jl @@ -37,6 +37,12 @@ struct Sibling2{T} b::T end +struct UnionMember + a::Float32 + b::Union{Function, Number} + c::Bool +end + @testset "TypeTree" begin @test tt(Float16) == "{[-1]:Float@half}" @test tt(Float32) == "{[-1]:Float@float}" @@ -55,28 +61,31 @@ end @test at2.z == 0.0 @test at2.type == 4 + if Sys.WORD_SIZE == 64 - @test tt(LList2{Float64}) == "{[8]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" + @test tt(UnionMember) == "{[0]:Float@float, [8]:Pointer, [16]:Integer}" + @test tt(LList2{Float64}) == "{[0]:Pointer, [8]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + 
"{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,0]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,0]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" else - @test tt(LList2{Float64}) == "{[4]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" + @test tt(UnionMember) == "{[0]:Float@float, [4]:Pointer, [8]:Integer}" + @test tt(LList2{Float64}) == "{[0]:Pointer, [4]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" + "{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,0]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,0]:Pointer, 
[24,4]:Float@float, [24,8]:Float@double}" end end @@ -91,4 +100,10 @@ end @test Enzyme.get_offsets(Ptr{Float32}) == ((Enzyme.API.DT_Pointer,0),) @test Enzyme.get_offsets(Vector{Float32}) == ((Enzyme.API.DT_Pointer,0),) @test Enzyme.get_offsets(Tuple{Float64, Int}) == [(Enzyme.API.DT_Double,0),(Enzyme.API.DT_Integer, 8)] + + if Sys.WORD_SIZE == 64 + @test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float,0),(Enzyme.API.DT_Pointer, 8), (Enzyme.API.DT_Integer, 16)] + else + @test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float, 0), (Enzyme.API.DT_Pointer, 4), (Enzyme.API.DT_Integer, 8)] + end end From d97bb83cdfdf55bf0abdbcbdfc6cf61d7f062d01 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 22:39:31 -0500 Subject: [PATCH 340/495] Stabilize global (#1928) --- src/rules/jitrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f169ab2c4b..bf98aaf885 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -458,7 +458,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) - gv = Core.getglobal(map(x->x.val, args)...) + gv = Core.getglobal(args[1].val, args[2].val) @assert sizeof(gv) == 0 (nothing, gv, nothing, Const) else From 66ef0f3566fae94f6af8da6f91dae7e36179cba6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 30 Sep 2024 13:37:50 -0500 Subject: [PATCH 341/495] Fix error exception (#1931) * Fix error exception * fix --------- Co-authored-by: William Moses --- src/Enzyme.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index aa018ea23b..17a7c6ff5d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -682,15 +682,19 @@ code, as well as high-order differentiation. 
if A isa UnionAll rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) - A2 = A{rt} + rt = Core.Compiler.return_type(f.val, tt) + A2 = A{rt} + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end else @assert A isa DataType rt = A + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end end - if rt == Union{} - error("Return type inferred to be Union{}. Giving up.") - end ModifiedBetweenT = falses_from_args(Nargs + 1) ModifiedBetween = Val(ModifiedBetweenT) From 1bc2ce18f0999740afc1b8f409ff370bc1b34dc4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 30 Sep 2024 13:46:14 -0500 Subject: [PATCH 342/495] Update Project.toml (#1932) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b70795fca1..f4a342a6d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.6" +version = "0.13.7" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Wed, 2 Oct 2024 13:40:25 -0400 Subject: [PATCH 343/495] Update DuplicatedNoNeed error message (#1933) * Update DuplicatedNoNeed error message * Update src/Enzyme.jl --- src/Enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 17a7c6ff5d..2e7789d660 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -600,7 +600,7 @@ f(x) = x*x if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed throw( ErrorException( - "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)", + "`DuplicatedNoNeed` passed in as return activity for Forward Mode AD is no longer returning or avoiding the primal.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)", ), ) end From 00dd3167069f2f55071f51e023ae3f6f6a09cb92 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 8 Oct 2024 21:38:46 -0500 Subject: [PATCH 344/495] Fix range step (#1945) * Fix range step * fix * cleanup * fix * Update internal_rules.jl --- src/internal_rules.jl | 106 ++++++++++++++++++++++++++++++++++++++++- test/internal_rules.jl | 37 ++++++++++++++ 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 53ca1f9283..6fe70df8cf 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -1121,6 +1121,110 @@ function EnzymeRules.forward( end end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(Base.range_start_stop_length)}, + RT, + start::Annotation{T}, + stop::Annotation{T}, + len::Annotation{<:Integer}, +) where T <: Base.IEEEFloat + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Duplicated( + func.val(start.val, stop.val, len.val), + func.val( + start isa Const ? zero(start.val) : -start.dval, + stop isa Const ? zero(stop.val) : stop.dval, + len.val) + ) + else + return BatchDuplicated( + func.val(start.val, stop.val, len.val), + ntuple( + i -> func.val( + start isa Const ? zero(start.val) : -start.dval[i], + stop isa Const ? zero(stop.val) : stop.dval[i], + len.val, + ), + Val(EnzymeRules.width(config)), + ), + ) + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return func.val( + start isa Const ? zero(start.val) : -start.dval, + stop isa Const ? zero(stop.val) : stop.dval, + len.val) + else + return ntuple( + i -> func.val( + start isa Const ? zero(start.val) : -start.dval[i], + stop isa Const ? 
zero(stop.val) : stop.dval[i], + len.val, + ), + Val(EnzymeRules.width(config)), + ) + end + elseif EnzymeRules.needs_primal(config) + return func.val(start.val, stop.val, len.val) + else + return nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.range_start_stop_length)}, + ::Type{RT}, + start::Annotation{T}, + stop::Annotation{T}, + len::Annotation{<:Base.Integer}, +) where {RT, T <: Base.IEEEFloat} + if EnzymeRules.needs_primal(config) + primal = func.val(start.val, stop.val, len.val) + else + primal = nothing + end + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.range_start_stop_length)}, + dret, + tape, + start::Annotation{T}, + stop::Annotation{T}, + len::Annotation{T3}, +) where {T <: Base.IEEEFloat, T3<:Integer} + dstart = if start isa Const + nothing + elseif EnzymeRules.width(config) == 1 + T(dret.val.ref.hi) - T(dret.val.step.hi) / (len.val - 1) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + T(dret.val[i].ref.hi) - T(dret.val[i].step.hi) / (len.val - 1) + end + end + + dstop = if stop isa Const + nothing + elseif EnzymeRules.width(config) == 1 + T(dret.val.step.hi) / (len.val - 1) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + T(dret.val[i].step.hi) / (len.val - 1) + end + end + + return (dstart, dstop, nothing) +end + + # Ranges # Float64 ranges in Julia use bitwise `&` with higher precision # to correct for numerical error, thus we put rules over the @@ -1196,8 +1300,6 @@ function EnzymeRules.forward( end end - - function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfig, func::Const{Colon}, diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 246929272b..3635ce07e2 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -630,7 +630,44 @@ end @test autodiff(Enzyme.Reverse, x -> 
rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) end + @testset "Ranges" begin + function f1(x) + x = 25.0x + ts = Array(Base.range_start_stop_length(0.0, x, 30)) + return sum(ts) + end + function f2(x) + x = 25.0x + ts = Array(Base.range_start_stop_length(0.0, 0.25, 30)) + return sum(ts) + x + end + function f3(x) + ts = Array(Base.range_start_stop_length(x, 1.25, 30)) + return sum(ts) + end + @test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (374.99999999999994,) + @test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,) + @test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (15.0,) + + @test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1" = 374.99999999999994, var"2" = 749.9999999999999),) + @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=25.0, var"2"=50.0),) + @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=15.0, var"2"=30.0),) + + @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((375.0,),) + @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),) + @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((15.0,),) + + # Batch active rule isnt setup + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((375.0,750.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((15.0,30.0)),) +end + +@testset "Ranges 2" begin function f1(x) x = 25.0x ts = Array(0.0:x:3.0) From 4d4c546dbafb048f7fe73195925bbd81a8d1105a Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 8 Oct 2024 21:39:10 -0500 Subject: [PATCH 345/495] Update Project.toml --- Project.toml | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f4a342a6d5..ddb8f163f4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.7" +version = "0.13.8" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From ad86689a8a6d7fd895f8427b3a7b977602d7828e Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 10 Oct 2024 12:40:14 -0500 Subject: [PATCH 346/495] Multi arg fwd gradient (#1952) * Multi arg fwd gradient * multi arg deriv * fix * fix * Update Enzyme.jl * cleanup * fix * Update Enzyme.jl * Update Enzyme.jl --- src/Enzyme.jl | 322 +++++++++++++++-------- test/runtests.jl | 469 +--------------------------------- test/sugar.jl | 646 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 869 insertions(+), 568 deletions(-) create mode 100644 test/sugar.jl diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2e7789d660..598dc872e9 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1794,16 +1794,30 @@ end @inline tupleconcat(x, y) = (x..., y...) @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) 
-function create_shadows(::Nothing, x) - return (onehot(x),) -end - -function create_shadows(::Val{1}, x) - return (onehot(x),) -end - -function create_shadows(::Val{chunk}, x) where {chunk} - return (chunkedonehot(x, Val(chunk)),) +@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N} + args = Union{Symbol,Expr}[:x] + tys = Type[X] + for i in 1:N + push!(args, :(vargs[$i])) + push!(tys, vargs[i]) + end + + exprs = Union{Symbol,Expr}[] + for (arg, ty) in zip(args, tys) + if ty <: Enzyme.Const + push!(exprs, :(nothing)) + elseif ty <: AbstractFloat + push!(exprs, :(nothing)) + elseif ChunkTy == Nothing || ChunkTy == Val{1} + push!(exprs, :(onehot($arg))) + else + push!(exprs, :(chunkedonehot($arg, chunk))) + end + end + return quote + Base.@_inline_meta + ($(exprs...),) + end end struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N} @@ -1890,7 +1904,7 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) (derivs = ([3.0, 2.0],), val = 6.0) ``` -For functions which return an AbstractArray or scalar, this function will return an AbstracttArray +For functions which return an AbstractArray or scalar, this function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made about the type of the AbstractArray returned by this function (which may or may not be the same as the input AbstractArray if provided). @@ -1905,119 +1919,227 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) # output ([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` + +This function supports multiple arguments and computes the gradient with respect to each of them. + +```jldoctest gradfwd2 +mul(x, y) = x[1]*y[2] + x[2]*y[1] + +gradient(Forward, mul, [2.0, 3.0], [2.7, 3.1]) + +# output + +([3.1, 2.7], [3.0, 2.0]) +``` + +This includes the ability to mark some arguments as `Const` if their derivatives are not needed; the corresponding entry of the returned derivative tuple is then `nothing`.
+ +```jldoctest gradfwd2 +gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) + +# output + +([3.1, 2.7], nothing) +``` """ -@inline function gradient( +@generated function gradient( fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, - f, - x; + f::F, + x::ty_0, + args::Vararg{Any,N}; chunk::CS = nothing, - shadows = create_shadows(chunk, x), -) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS} - if length(shadows[1]) == 0 - return if ReturnPrimal - (; derivs = (x,), val = f(x.val)) + shadows::ST = create_shadows(chunk, x, args...), +) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS,ST, ty_0, N} + + syms = Union{Symbol,Expr}[:x] + shads = Union{Symbol,Expr}[:(shadows[1])] + tys = Type[ty_0] + for i in 1:N + push!(syms, :(args[$i])) + push!(tys, args[i]) + push!(shads, :(shadows[1+$i])) + end + fval = if F <: Annotation + :(f.val) + else + :f + end + + vals = Union{Symbol,Expr}[] + consts = Union{Symbol,Expr}[] + for (arg, ty) in zip(syms, tys) + if ty <: Const + push!(vals, :($arg.val)) + push!(consts, arg) else - (x,) + push!(vals, arg) + push!(consts, :(Const($arg))) end end - if chunk == Val(0) - throw(ErrorException("Cannot differentiate with a batch size of 0")) + + if CS == Val{0} + return quote + Base.@_inline_meta + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end end - gradtup = if chunk == nothing - resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1])) + exprs = Union{Symbol,Expr}[] + primal = nothing + derivatives = Union{Symbol,Expr}[] - res = values(resp[1]) - dres = if x isa AbstractFloat - res[1] - else - res + primmode = :(fm) + for (i, (arg, ty)) in enumerate(zip(syms, tys)) + if ty <: Const + push!(derivatives, :(nothing)) + continue end - if ReturnPrimal - ((dres,), resp[2]) - else - (dres,) - end - elseif chunk == Val(1) - if ReturnPrimal - rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1])) - dres1 = rp[1] - fm2 = 
ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# - res = ntuple(length(shadows[1]) - 1) do i - autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1] + argnum = length(ST.parameters[i].parameters) + + argderivative = if ty <: AbstractFloat + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(Duplicated($arg, one($arg)))) + else + push!(dargs, consts[j]) + end end - gres = if x isa AbstractFloat - dres1[1] - else - (dres1, res...) + + resp = Symbol("resp_$i") + push!(exprs, quote + $resp = autodiff($primmode, f, Duplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) end - ((gres,), rp[2]) - else - res = ntuple(length(shadows[1])) do i - autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][i]))[1] + + :($resp[1]) + elseif argnum == 0 + vals[i] + elseif CS == Nothing + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(BatchDuplicated($arg, $(shads[i])))) + else + push!(dargs, consts[j]) + end end - (if x isa AbstractFloat - res[1] - else - res - end,) - end - else - if ReturnPrimal - rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][1])) - dres1 = values(rp[1]) - gres = if x isa AbstractFloat - dres1[1] - else - fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# - tmp = ntuple(length(shadows[1]) - 1) do i - values( - autodiff( - fm2, - f, - BatchDuplicated, - BatchDuplicated(x, shadows[1][i+1]), - )[1], - ) + + df = :f + if F <: Enzyme.Duplicated + zeros = Expr[] + for i in 1:argnum + push!(zeros, :(f.dval)) end - tupleconcat(dres1, tmp...) 
+ df = :(BatchDuplicated(f.val, ($(zeros...),) )) + end + + resp = Symbol("resp_$i") + push!(exprs, quote + $resp = autodiff($primmode, $df, BatchDuplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) end - ((gres,), rp[2]) + + :(values($resp[1])) + elseif CS == Val{1} + subderivatives = Union{Symbol,Expr}[] + for an in 1:argnum + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(Duplicated($arg, $(shads[i])[$an]))) + else + push!(dargs, consts[j]) + end + end + + resp = Symbol("resp_$i"*"_"*string(an)) + push!(exprs, quote + $resp = autodiff($primmode, f, Duplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + end + + push!(subderivatives, :(values($resp[1]))) + end + :(($(subderivatives...),)) else - tmp = ntuple(length(shadows[1])) do i - values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i]))[1]) + subderivatives = Union{Symbol,Expr}[] + for an in 1:argnum + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(BatchDuplicated($arg, $(shads[i])[$an]))) + else + push!(dargs, consts[j]) + end + end + + resp = Symbol("resp_$i"*"_"*string(an)) + push!(exprs, quote + $resp = autodiff($primmode, f, BatchDuplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + end + + push!(subderivatives, :(values($resp[1]))) end - res = tupleconcat(tmp...) 
- (if x isa AbstractFloat - res[1] + :(tupleconcat($(subderivatives...))) + end + + deriv = if ty <: AbstractFloat + argderivative + else + tmp = Symbol("tmp_$i") + push!(exprs, :($tmp = $argderivative)) + if ty <: AbstractArray + if argnum > 0 + quote + if $tmp[1] isa AbstractArray + inshape = size($(vals[i])) + outshape = size($tmp[1]) + # st : outshape x total inputs + tupstack($tmp, outshape, inshape) + else + TupleArray($tmp, size($arg)) + end + end + else + :(TupleArray($tmp, size($arg))) + end else - res - end,) + tmp + end end + push!(derivatives, deriv) end - cols = if ReturnPrimal - gradtup[1][1] - else - gradtup[1] - end - res = if x isa AbstractFloat - cols - elseif length(cols) > 0 && cols[1] isa AbstractArray && x isa AbstractArray - inshape = size(x) - outshape = size(cols[1]) - # st : outshape x total inputs - tupstack(cols, outshape, inshape) - elseif x isa AbstractArray - TupleArray(cols, size(x)) - else - cols + # We weirdly asked for no derivatives + if ReturnPrimal && primal == nothing + primal = :($fval($(vals...))) end - if ReturnPrimal - (; derivs = (res,), val = gradtup[2]) + + result = if ReturnPrimal + :((; derivs = ($(derivatives...),), val = $primal)) else - (res,) + :(($(derivatives...),)) + end + + return quote + Base.@_inline_meta + $(exprs...) + $result end end diff --git a/test/runtests.jl b/test/runtests.jl index 902b9e4f65..c3856aabf1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,15 +16,6 @@ using InlineStrings using Enzyme_jll @info "Testing against" Enzyme_jll.libEnzyme -# symbol is \simeq -# this is basically a more flexible version of ≈ -(≃)(a, b) = (≈)(a, b) -(≃)(a::Tuple, b::Tuple) = all(xy -> xy[1] ≃ xy[2], zip(a,b)) -function (≃)(a::AbstractArray{<:Tuple}, b::AbstractArray{<:Tuple}) - size(a) == size(b) || return false - all(xy -> xy[1] ≃ xy[2], zip(a,b)) -end - function isapproxfn(fn, args...; kwargs...) isapprox(args...; kwargs...)
end @@ -2938,465 +2929,7 @@ end @test dx ≈ [-1.0, 43.74, 0] end - -# these are used in gradient and jacobian tests -struct InpStruct - i1::Float64 - i2::Float64 - i3::Float64 -end -struct OutStruct - i1::Float64 - i2::Float64 - i3::Float64 -end - -for A ∈ (:InpStruct, :OutStruct) - @eval (≃)(a::$A, b::$A) = (a.i1 ≃ b.i1) && (a.i2 ≃ b.i2) && (a.i3 ≃ b.i3) - @eval function (≃)(a::AbstractArray{<:$A}, b::AbstractArray{<:$A}) - size(a) == size(b) || return false - all(xy -> xy[1] ≃ xy[2], zip(a, b)) - end -end - - -#NOTE: this is needed because of problems with hvcat on 1.10 and something inexplicable on 1.6 -# suffice it to say it's not good that this is required, please remove when possible -mkarray(sz, args...) = reshape(vcat(args...), sz) - -@testset "Gradient and Jacobian Outputs" begin - - scalar = 3.0 - - # ∂ scalar / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 - - # ∂ vector / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - - - # ∂ tuple / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≈ [2.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), 
scalar)[1] ≃ (2.0, 6.0) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) - - mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) - - # ∂ matrix / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - - @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - - # ∂ struct / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) - - - - vector = [2.7, 3.1] - - # ∂ scalar / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2],vector[1]] - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] - - - # ∂ vector / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test 
Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - - # ∂ tuple / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ - [(vector[2], -sin(vector[1])), (vector[1], 1.0)] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ - ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ - [(vector[2], -sin(vector[1])), (vector[1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] - - mkarray2 = x -> mkarray((2,2), x[1]*x[2], exp(x[2]), cos(x[1])+x[2], x[1]) - - # ∂ matrix / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector)[1] ≈ - mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector)[1] - @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector)[1] ≈ - mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector)[1] ≈ - mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - - # ∂ struct / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ - [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ - [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] 
≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - - - - tuplev = (2.7, 3.1) - - # ∂ scalar / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - - # ∂ vector / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ - ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ - [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ - [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] - - # ∂ tuple / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ - ((vector[2], -sin(vector[1])), (vector[1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ - ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ - [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] - - # ∂ matrix / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev)[1] ≃ - ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev)[1] ≈ - [tuplev[2] 
-sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev)[1] ≈ - [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] - - # ∂ struct / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ - (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ - [OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - - - - matrix = [2.7 3.1; 4.7 5.6] - - # ∂ scalar / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - - # ∂ vector / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ - mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] - # again we can't use array construction syntax because of 1.6 - @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ - 
mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ - mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - - # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ - [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] - @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) - @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ - [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] - - mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) - - # ∂ matrix / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix)[1] ≈ - mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, - matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix)[1] - # array construction syntax broken on 1.6 - @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix)[1] ≈ - mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, - matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix)[1] ≈ - mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, - matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - - # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ - [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] - @test_broken 
Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] - @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ - [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] - - - istruct = InpStruct(2.7, 3.1, 4.7) - - # ∂ scalar / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] - @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] - @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) - - # ∂ vector / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] - - # ∂ tuple / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - - mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) - - # ∂ matrix / ∂ struct - @test_broken 
Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] - @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct)[1] ≃ - [InpStruct(istruct.i2, istruct.i1, 0.0) InpStruct(1.0, 0.0, -sin(istruct.i3)); - InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] - - # ∂ struct / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] -end - -@testset "Simple Jacobian" begin - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0)[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0)[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0])[1] ≈ [4.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(1))[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(1))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(1))[1] ≈ [4.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(2))[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(2))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(2))[1] ≈ [4.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), 
chunk=Val(1))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(2))[1] ≈ [1.0, 2.0] - - x = float.(reshape(1:6, 2, 3)) - - fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] - - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x)[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(1))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(2))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(1))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(2))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 
1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] - - x2 = InpStruct(1.0, 2.0, 3.0) - - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(1))[1] - - @test jac[1] == InpStruct(2.0, 4.0, 6.0) - @test jac[2] == InpStruct(20.0, 40.0, 60.0) - @test jac[3] == InpStruct(200.0, 400.0, 600.0) - @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(2))[1] - - @test jac[1] == InpStruct(2.0, 4.0, 6.0) - @test jac[2] == InpStruct(20.0, 40.0, 60.0) - @test jac[3] == InpStruct(200.0, 400.0, 600.0) - @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - - filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) - - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x)[1] - - @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) - @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) - - @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) - @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) - - @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) - @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(1))[1] - - @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) - @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) - - @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) - @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) - - @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) - @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(2))[1] - - @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) - @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) - - @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) - @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) - - @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) - @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) -end - - -@testset "Jacobian" begin - function inout(v) - [v[2], v[1]*v[1], v[1]*v[1]*v[1]] - end - - jac = Enzyme.jacobian(Reverse, inout, [2.0, 
3.0], n_outs=Val((3,)), chunk=Val(1))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(1))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])[1] - - jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(2))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(2))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - function f_test_1(A, x) - utmp = A*x[2:end] .+ x[1] - return utmp - end - - function f_test_2(A, x) - utmp = Vector{Float64}(undef, length(x)-1) - utmp .= A*x[2:end] .+ x[1] - return utmp - end - - function f_test_3!(u, A, x) - utmp .= A*x[2:end] .+ x[1] - end - - J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, n_outs=Val((5,)))[1] - J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, n_outs=Val((5,)))[1] - J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, n_outs=Val((5,)))[1] - - J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x)[1] - J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x)[1] - J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x)[1] - - x = ones(6) - A = Matrix{Float64}(LinearAlgebra.I, 5, 5) - u = Vector{Float64}(undef, 5) - - @test J_r_1(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - - @test J_r_2(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - - @test J_f_1(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - 
@test J_f_2(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - - # @show J_r_3(u, A, x) - # @show J_f_3(u, A, x) -end +include("sugar.jl") @testset "Forward on Reverse" begin diff --git a/test/sugar.jl b/test/sugar.jl new file mode 100644 index 0000000000..c558fd813e --- /dev/null +++ b/test/sugar.jl @@ -0,0 +1,646 @@ +using Enzyme, Test + + +mul_scalar(x, y) = x[1]*y[2] + x[2]*y[1] +mul_vector(x, y) = [x[1]*y[2], x[2]*y[1]] + +@testset "Forward Multi-Arg Gradient" begin + res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]) + @test res[1] ≈ [3.1, 2.7] + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(2)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], [2.7, 3.1]) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] ≈ [3.0, 2.0] + + + + res = gradient(Forward, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]) + @test res[1] == nothing + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res[1] == nothing + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res[1] == nothing + @test res[2] ≈ [3.0, 2.0] + + res = 
gradient(ForwardWithPrimal, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [3.0, 2.0] + + + res = gradient(Forward, mul_scalar, [2.0, 3.0], Const([2.7, 3.1])) + @test res[1] ≈ [3.1, 2.7] + @test res[2] == nothing + + res = gradient(Forward, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] == nothing + + res = gradient(Forward, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] == nothing + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], Const([2.7, 3.1])) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] == nothing + + + + res = gradient(Forward, mul_vector, [2.0, 3.0], [2.7, 3.1]) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, [2.0, 
3.0], [2.7, 3.1]; chunk=Val(2)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], [2.7, 3.1]) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + + + res = gradient(Forward, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]) + @test res[1] == nothing + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res[1] == nothing + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res[1] == nothing + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + + res = gradient(Forward, mul_vector, [2.0, 3.0], Const([2.7, 3.1])) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] == nothing + + 
res = gradient(Forward, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] == nothing + + res = gradient(Forward, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] == nothing + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], Const([2.7, 3.1])) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] == nothing + +end + +# these are used in gradient and jacobian tests +struct InpStruct + i1::Float64 + i2::Float64 + i3::Float64 +end +struct OutStruct + i1::Float64 + i2::Float64 + i3::Float64 +end + +# symbol is \simeq +# this is basically a more flexible version of ≈ +(≃)(a, b) = (≈)(a, b) +(≃)(a::Tuple, b::Tuple) = all(xy -> xy[1] ≃ xy[2], zip(a,b)) +function (≃)(a::AbstractArray{<:Tuple}, b::AbstractArray{<:Tuple}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a,b)) +end + +for A ∈ (:InpStruct, :OutStruct) + @eval (≃)(a::$A, b::$A) = (a.i1 ≃ b.i1) && (a.i2 ≃ b.i2) && (a.i3 ≃ b.i3) + @eval function (≃)(a::AbstractArray{<:$A}, b::AbstractArray{<:$A}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a, b)) + end +end + + +#NOTE: this is needed because of problems with hvcat on 1.10 and something inexplicable on 1.6 +# suffice it to say it's not good that this is required, please remove when possible +mkarray(sz, args...) 
= reshape(vcat(args...), sz) + +@testset "Gradient and Jacobian Outputs" begin + + scalar = 3.0 + + # ∂ scalar / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 + + # ∂ vector / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + + + # ∂ tuple / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + + mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) + + # ∂ matrix / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + # ∂ struct / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == 
OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) + + + + vector = [2.7, 3.1] + + # ∂ scalar / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2],vector[1]] + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + + + # ∂ vector / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + + # ∂ tuple / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ + ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] + + mkarray2 = x -> mkarray((2,2), x[1]*x[2], 
exp(x[2]), cos(x[1])+x[2], x[1]) + + # ∂ matrix / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector)[1] + @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + + # ∂ struct / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + + + tuplev = (2.7, 3.1) + + # ∂ scalar / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + + # ∂ vector / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ + ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], 
tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ + [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] + + # ∂ tuple / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ + ((vector[2], -sin(vector[1])), (vector[1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ + ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + + # ∂ matrix / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev)[1] ≃ + ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev)[1] ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev)[1] ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] + + # ∂ struct / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ + (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ + [OutStruct(tuplev[2], 
-sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + + + matrix = [2.7 3.1; 4.7 5.6] + + # ∂ scalar / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + + # ∂ vector / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] + # again we can't use array construction syntax because of 1.6 + @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] + @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] + @test_broken 
Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] + + mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) + + # ∂ matrix / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix)[1] + # array construction syntax broken on 1.6 + @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] + @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] + + + istruct = InpStruct(2.7, 3.1, 4.7) + + # ∂ scalar / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, 
istruct.i1, 1.0) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + + # ∂ vector / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] + + # ∂ tuple / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + + mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) + + # ∂ matrix / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct)[1] ≃ + [InpStruct(istruct.i2, istruct.i1, 0.0) InpStruct(1.0, 0.0, -sin(istruct.i3)); + InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] + + # ∂ struct / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) 
+ x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] +end + +@testset "Simple Jacobian" begin + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0)[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0])[1] ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(1))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(1))[1] ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(2))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(2))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(2))[1] ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(2))[1] ≈ [1.0, 2.0] + + x = float.(reshape(1:6, 2, 3)) + + fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x)[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(1))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ 
[200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(2))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(1))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(2))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] + + x2 = InpStruct(1.0, 2.0, 3.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(1))[1] + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(2))[1] + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) + + jac = Enzyme.jacobian(Enzyme.Forward, 
filloutabs2, x)[1] + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(1))[1] + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(2))[1] + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) +end + + +@testset "Jacobian" begin + function inout(v) + [v[2], v[1]*v[1], v[1]*v[1]*v[1]] + end + + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(1))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(1))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])[1] + + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(2))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(2))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + function f_test_1(A, x) + utmp = A*x[2:end] .+ x[1] + return utmp + end + + function f_test_2(A, x) + utmp = 
Vector{Float64}(undef, length(x)-1) + utmp .= A*x[2:end] .+ x[1] + return utmp + end + + function f_test_3!(u, A, x) + utmp .= A*x[2:end] .+ x[1] + end + + J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, n_outs=Val((5,)))[1] + J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, n_outs=Val((5,)))[1] + J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, n_outs=Val((5,)))[1] + + J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x)[1] + J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x)[1] + J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x)[1] + + x = ones(6) + A = Matrix{Float64}(LinearAlgebra.I, 5, 5) + u = Vector{Float64}(undef, 5) + + @test J_r_1(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + + @test J_r_2(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + + @test J_f_1(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + @test J_f_2(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + + # @show J_r_3(u, A, x) + # @show J_f_3(u, A, x) +end From 1bbcaa9efb43d5e1c59ed85229e3b3dd20ffdbd3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 11 Oct 2024 21:01:39 +0200 Subject: [PATCH 347/495] Add docstring for `guess_activity` (#1955) --- src/compiler.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 16d54481cf..a951b53d5f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -856,6 +856,11 @@ end return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end 
+""" + Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) + +Try to guess the most appropriate [`Annotation`](@ref) for arguments of type `T` passed to [`autodiff`](@ref) with a given `mode`. +""" @inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) From 2a1f213e5282805a0c46c5d5f84d72cc4d192648 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 14 Oct 2024 19:57:20 -0500 Subject: [PATCH 348/495] WIP: Improve 1.11 support (#1963) * WIP: Improve 1.11 support * fix * fix --- src/compiler.jl | 4 +--- src/compiler/optimize.jl | 6 ++++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index a951b53d5f..b7db211b7a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5384,10 +5384,8 @@ function create_abi_wrapper( end end - cf = nested_codegen!(Mode, mod, add_one_in_place, Tuple{Any}, world) - push!(function_attributes(cf), EnumAttribute("alwaysinline", 0)) for shadowv in shadows - c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) + c = emit_apply_generic!(builder, [unsafe_to_llvm(builder, add_one_in_place), shadowv]) if get_subprogram(llvm_f) !== nothing metadata(c)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index cc143ce4f2..eccc5789dd 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -791,8 +791,10 @@ function nodecayed_phis!(mod::LLVM.Module) if addr == 13 && hasload && addrspace(value_type(v)) == 10 return v, offset, hasload end - if addr == 13 && isa(v, LLVM.LoadInst) && !hasload - return getparent(operands(v)[1], offset, true) + if addr == 13 && !hasload + if isa(v, LLVM.LoadInst) + return getparent(operands(v)[1], offset, true) + end end if addr == 13 && isa(v, LLVM.ConstantExpr) From 46057161ae0c2a8824f74d337e34bde86c58f094 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 14 Oct 2024 21:58:57 -0500 Subject: [PATCH 349/495] 
Speed up onehot of arrays (#1953) * Speed up onehot of arrays * faster tupstack * fix * fix * fix assert error * Better and GC safe version * gc push * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --- ext/EnzymeStaticArraysExt.jl | 6 +- src/Enzyme.jl | 31 +++++++-- src/compiler.jl | 119 +++++++++++++++++++++++++++++++++++ test/sugar.jl | 2 +- 4 files changed, 150 insertions(+), 8 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index af31d405d7..eaae75cccb 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -8,8 +8,8 @@ using Enzyme end @inline Base.convert(::Type{StaticArray}, tpa::Enzyme.TupleArray) = convert(SArray, tpa) -@inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) - reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...)) +@inline function Enzyme.tupstack(rows::Tuple{Vararg{T}}, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) where {T<:StaticArrays.SArray} + reshape(reduce(hcat, map(vec, rows)), Size(outshape..., inshape...)) end @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} @@ -19,7 +19,7 @@ end end end -@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}, start, endl) where {S, T, N, L} +@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}, start::Int, endl::Int) where {S, T, N, L} ntuple(Val(endl-start+1)) do i Base.@_inline_meta StaticArrays.SArray{S, T, N, L}( diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 598dc872e9..51b40cee37 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1518,7 +1518,21 @@ end nothing end -@inline function onehot(x) +function zerosetfn(x, i::Int) + res = zero(x) + @inbounds res[i] = 1 + return res +end + +@inline function onehot(x::Array) + Compiler.onehot_internal(zerosetfn, x, 0, length(x)) +end + +@inline function onehot(x::Array, start::Int, endl::Int) + Compiler.onehot_internal(zerosetfn, x, 
start-1, endl-start+1) +end + +@inline function onehot(x::AbstractArray) N = length(x) ntuple(Val(N)) do i Base.@_inline_meta @@ -1529,7 +1543,7 @@ end return res end end -@inline function onehot(x, start, endl) +@inline function onehot(x::AbstractArray, start::Int, endl::Int) ntuple(Val(endl - start + 1)) do i Base.@_inline_meta res = similar(x) @@ -1852,12 +1866,21 @@ function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} return a.data[start] end -@inline function tupstack(x, inshape, outshape) +@inline function tupstack(data::Tuple{Vararg{<:Array{T}}}, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) where {T} + num = prod(outshape) + res = Array{T}(undef, outshape..., inshape...) + for (i, val) in enumerate(data) + Base.unsafe_copyto!(res, num*(i-1)+1, val, 1, Base.reinterpret(UInt, num)) + end + res +end + +@inline function tupstack(x, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) st = Base.stack(x) if length(outshape) == 1 st else - reshape(st, (inshape..., outshape...)) + reshape(st, (outshape..., inshape...)) end end diff --git a/src/compiler.jl b/src/compiler.jl index b7db211b7a..e58357d8a0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -9362,4 +9362,123 @@ end include("compiler/reflection.jl") +@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:Array} + ir = JuliaContext() do ctx + Base.@_inline_meta + + target = Compiler.DefaultCompilerTarget() + params = Compiler.PrimalCompilerParams(API.DEM_ForwardMode) + mi = GPUCompiler.methodinstance(fn, Tuple{T, Int}) + job = CompilerJob(mi, CompilerConfig(target, params; kernel = false)) + mod, meta = GPUCompiler.codegen( + :llvm, + job; + optimize = false, + cleanup = false, + validate = false, + ) + copysetfn = meta.entry + blk = first(blocks(copysetfn)) + for inst in collect(instructions(blk)) + if isa(inst, LLVM.FenceInst) + eraseInst(blk, inst) + end + if isa(inst, LLVM.CallInst) + fn = LLVM.called_operand(inst) + if isa(fn, 
LLVM.Function) + if LLVM.name(fn) == "julia.safepoint" + eraseInst(blk, inst) + end + end + end + end + hasNoRet = any( + map( + k -> kind(k) == kind(EnumAttribute("noreturn")), + collect(function_attributes(copysetfn)), + ), + ) + @assert !hasNoRet + if !hasNoRet + push!(function_attributes(copysetfn), EnumAttribute("alwaysinline", 0)) + end + ity = convert(LLVMType, Int) + jlvaluet = convert(LLVMType, T; allow_boxed=true) + + FT = LLVM.FunctionType(jlvaluet, LLVMType[jlvaluet, ity, ity]) + llvm_f = LLVM.Function(mod, "f", FT) + push!(function_attributes(llvm_f), EnumAttribute("alwaysinline", 0)) + + # Check if Julia version has https://github.com/JuliaLang/julia/pull/46914 + # and also https://github.com/JuliaLang/julia/pull/47076 + # and also https://github.com/JuliaLang/julia/pull/48620 + needs_dynamic_size_workaround = !(VERSION >= v"1.10.5") + + builder = LLVM.IRBuilder() + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) + inp, lstart, len = collect(LLVM.Value, parameters(llvm_f)) + + boxed_count = if sizeof(Int) == sizeof(Int64) + emit_box_int64!(builder, len) + else + emit_box_int32!(builder, len) + end + + tag = emit_apply_type!(builder, NTuple, (boxed_count, unsafe_to_llvm(builder, T))) + + fullsize = nuwmul!(builder, len, LLVM.ConstantInt(sizeof(Int))) + obj = emit_allocobj!(builder, tag, fullsize, needs_dynamic_size_workaround) + + T_int8 = LLVM.Int8Type() + LLVM.memset!(builder, obj, LLVM.ConstantInt(T_int8, 0), fullsize, 0) + + alloc = pointercast!(builder, obj, LLVM.PointerType(jlvaluet, Tracked)) + alloc = pointercast!(builder, alloc, LLVM.PointerType(jlvaluet, 11)) + + loop = BasicBlock(llvm_f, "loop") + exit = BasicBlock(llvm_f, "exit") + + br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, LLVM.ConstantInt(0), len), exit, loop) + + position!(builder, loop) + idx = phi!(builder, ity) + + push!(LLVM.incoming(idx), (LLVM.ConstantInt(0), entry)) + inc = add!(builder, idx, LLVM.ConstantInt(1)) + push!(LLVM.incoming(idx), (inc, loop)) + 
rval = add!(builder, inc, lstart) + res = call!(builder, LLVM.function_type(copysetfn), copysetfn, [inp, rval]) + if !hasNoRet + gidx = gep!(builder, jlvaluet, alloc, [idx]) + store!(builder, res, gidx) + emit_writebarrier!(builder, get_julia_inner_types(builder, obj, res)) + end + + br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, inc, len), exit, loop) + + + T_int32 = LLVM.Int32Type() + + reinsert_gcmarker!(llvm_f) + + position!(builder, exit) + ret!(builder, obj) + + string(mod) + end + return quote + Base.@_inline_meta + Base.llvmcall( + ($ir, "f"), + Tuple{Vararg{T}}, + Tuple{T, Int, Int}, + x, + startv, + lengthv + ) + end +end + + end diff --git a/test/sugar.jl b/test/sugar.jl index c558fd813e..097472ab22 100644 --- a/test/sugar.jl +++ b/test/sugar.jl @@ -1,5 +1,5 @@ using Enzyme, Test - +using LinearAlgebra mul_scalar(x, y) = x[1]*y[2] + x[2]*y[1] mul_vector(x, y) = [x[1]*y[2], x[2]*y[1]] From a6de5c4ef1fc67618fd68cc0ea4c975a2cae3a3a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 14 Oct 2024 21:59:31 -0500 Subject: [PATCH 350/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ddb8f163f4..982a7c338f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.8" +version = "0.13.9" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From c0c5e5169987612641383de68b67d6dadb334f1a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 16 Oct 2024 00:02:50 +0200 Subject: [PATCH 351/495] Fix code coverage & update action versions (#1954) --- .github/workflows/CI.yml | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 60d713c529..e5c92993e1 100644 --- a/.github/workflows/CI.yml +++ 
b/.github/workflows/CI.yml @@ -54,13 +54,13 @@ jobs: version: '1.11' assertions: true steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 if: ${{ ! matrix.assertions }} with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 if: ${{ matrix.assertions }} with: repository: 'JuliaLang/julia' @@ -72,7 +72,7 @@ jobs: sed -i.bak 's/exit 2/exit 0/g' julia/deps/tools/jlchecksum make -C julia -j $(nproc) FORCE_ASSERTIONS=1 LLVM_ASSERTIONS=1 JULIA_PRECOMPILE=0 echo $PWD/julia/usr/bin >> $GITHUB_PATH - - uses: actions/cache@v1 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: @@ -120,10 +120,12 @@ jobs: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - uses: julia-actions/julia-processcoverage@v1 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' - - uses: codecov/codecov-action@v1 + - uses: codecov/codecov-action@v4 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: file: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false # or true if you want CI to fail when Codecov fails enzymetestutils: name: EnzymeTestUtils - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libEnzyme }} libEnzyme - ${{ github.event_name }} runs-on: ${{ matrix.os }} @@ -143,12 +145,12 @@ jobs: - x64 libEnzyme: [packaged] steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: @@ -180,10 +182,12 @@ jobs: if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: directories: lib/EnzymeTestUtils/src - - uses: codecov/codecov-action@v2 + - uses: codecov/codecov-action@v4 if: 
matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false # or true if you want CI to fail when Codecov fails integration: name: Integration Tests - ${{ matrix.test }} runs-on: ${{ matrix.os }} @@ -200,10 +204,10 @@ jobs: - DynamicExpressions steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v4 - uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" run: | @@ -214,11 +218,11 @@ jobs: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: '1' - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v4 - run: | julia --project=docs -e ' using Pkg From ec68ae6636ea6f8cefd8cf5055dfe8b15eba23cb Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 15 Oct 2024 19:42:41 -0500 Subject: [PATCH 352/495] Static array return for forward gradient (#1943) * Static array return for forward gradient * cleanup * fix --- ext/EnzymeStaticArraysExt.jl | 2 + src/Enzyme.jl | 6 +- test/ext/sparsearrays.jl | 32 ++++++++++ test/ext/staticarrays.jl | 92 +++++++++++++++++++++++++++ test/runtests.jl | 118 +---------------------------------- 5 files changed, 133 insertions(+), 117 deletions(-) create mode 100644 test/ext/sparsearrays.jl create mode 100644 test/ext/staticarrays.jl diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index eaae75cccb..c2639a4c99 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -12,6 +12,8 @@ end reshape(reduce(hcat, map(vec, rows)), Size(outshape..., inshape...)) end +@inline Enzyme.specialize_output(output, input::StaticArray) = convert(SArray, output) + @inline function 
Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i Base.@_inline_meta diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 51b40cee37..7caa06c281 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1884,6 +1884,8 @@ end end end +@inline specialize_output(output, input) = output + """ gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) @@ -2135,11 +2137,11 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) # st : outshape x total inputs tupstack($tmp, outshape, inshape) else - TupleArray($tmp, size($arg)) + specialize_output(TupleArray($tmp, size($arg)), $(vals[1])) end end else - :(TupleArray($tmp, size($arg))) + :(specialize_output(TupleArray($tmp, size($arg)), $(vals[1]))) end else tmp diff --git a/test/ext/sparsearrays.jl b/test/ext/sparsearrays.jl new file mode 100644 index 0000000000..28b90fb6ee --- /dev/null +++ b/test/ext/sparsearrays.jl @@ -0,0 +1,32 @@ +using Enzyme, Test + +using SparseArrays + +@testset "Gradient & SparseArrays" begin + x = sparse([5.0, 0.0, 6.0]) + dx = Enzyme.gradient(Reverse, sum, x)[1] + @test dx isa SparseVector + @test dx ≈ [1, 0, 1] + + x = sparse([5.0 0.0 6.0]) + dx = Enzyme.gradient(Reverse, sum, x)[1] + @test dx isa SparseMatrixCSC + @test dx ≈ [1 0 1] +end + +function sparse_eval(x::Vector{Float64}) + A = sparsevec([1, 1, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0]) + B = sparsevec([1, 1, 2, 3], [2.0*x[2], 1.0-x[1], 2.0+x[3], -1.0]) + C = A + B + return A[1] +end + +@testset "Type Unstable SparseArrays" begin + x = [3.1, 2.7, 8.2] + dx = [0.0, 0.0, 0.0] + + autodiff(Reverse, sparse_eval, Duplicated(x, dx)) + + @test x ≈ [3.1, 2.7, 8.2] + @test dx ≈ [-1.0, 43.74, 0] +end \ No newline at end of file diff --git a/test/ext/staticarrays.jl b/test/ext/staticarrays.jl new file mode 100644 index 0000000000..c2c55a9aa4 --- /dev/null +++ b/test/ext/staticarrays.jl @@ -0,0 +1,92 @@ +using Enzyme, Test + +using StaticArrays + +@testset "Gradient & StaticArrays" begin + + x = @SArray 
[5.0 0.0 6.0] + dx = Enzyme.gradient(Reverse, prod, x)[1] + @test dx isa SArray + @test dx ≈ [0 30 0] + + x = @SVector [1.0, 2.0, 3.0] + y = onehot(x) + # this should be a very specific type of SArray, but there + # is a bizarre issue with older julia versions where it can be MArray + @test eltype(y) <: StaticVector + @test length(y) == 3 + @test y[1] == [1.0, 0.0, 0.0] + @test y[2] == [0.0, 1.0, 0.0] + @test y[3] == [0.0, 0.0, 1.0] + + y = onehot(x, 2, 3) + @test eltype(y) <: StaticVector + @test length(y) == 2 + @test y[1] == [0.0, 1.0, 0.0] + @test y[2] == [0.0, 0.0, 1.0] + + x = @SArray [5.0 0.0 6.0] + dx = Enzyme.gradient(Forward, prod, x)[1] + @test dx isa SArray + @test dx ≈ [0 30 0] + + f0 = x -> sum(2*x) + f1 = x -> @SVector Float64[x[2], 2*x[2]] + f2 = x -> @SMatrix Float64[x[2] x[1]; 2*x[2] 2*x[1]] + + x = @SVector Float64[1, 2] + + @inferred gradient(Forward, f0, x) + dx = gradient(Forward, f0, x)[1] + @test dx isa SVector + @test dx == [2.0, 2.0] # test to make sure conversion works + @test gradient(Forward, f1, x)[1] isa SMatrix + @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + @inferred gradient(Forward, f0, x) + dx = gradient(Forward, f0, x)[1] + @test dx isa SMatrix + @test dx == fill(2.0, (2,2)) + @test gradient(Forward, f1, x)[1] isa SArray + @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) + + x = @SVector Float64[1, 2] + + @inferred gradient(Reverse, f0, x) + dx = gradient(Reverse, f0, x)[1] + @test dx isa SVector + @test dx == [2.0, 2.0] # test to make sure conversion works + @test_broken gradient(Reverse, f1, x)[1] isa SMatrix + @test_broken gradient(Reverse, f1, 
x)[1] == [0 1.0; 0 2.0] + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + @test_broken gradient(Reverse, f1, x)[1] isa SArray + @test_broken gradient(Reverse, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) +end + +function unstable_fun(A0) + A = 'N' in ('H', 'h', 'S', 's') ? wrap(A0) : A0 + (@inbounds A[1])::eltype(A0) +end +@testset "Type unstable static array index" begin + inp = ones(SVector{2, Float64}) + res = Enzyme.gradient(Enzyme.Reverse, unstable_fun, inp)[1] + @test res ≈ [1.0, 0.0] + res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1] + @test res ≈ [1.0, 0.0] +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c3856aabf1..f421f35625 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,8 +7,6 @@ using Enzyme using Test using FiniteDifferences using Aqua -using SparseArrays -using StaticArrays using Statistics using LinearAlgebra using InlineStrings @@ -2816,119 +2814,6 @@ end @test grad == 14.0 end -@testset "Gradient & SparseArrays / StaticArrays" begin - x = sparse([5.0, 0.0, 6.0]) - dx = Enzyme.gradient(Reverse, sum, x)[1] - @test dx isa SparseVector - @test dx ≈ [1, 0, 1] - - x = sparse([5.0 0.0 6.0]) - dx = Enzyme.gradient(Reverse, sum, x)[1] - @test dx isa SparseMatrixCSC - @test dx ≈ [1 0 1] - - x = @SArray [5.0 0.0 6.0] - dx = Enzyme.gradient(Reverse, prod, x)[1] - @test dx isa SArray - @test dx ≈ [0 30 0] - - x = @SVector [1.0, 2.0, 3.0] - y = onehot(x) - # this should be a very specific type of SArray, but there - # is a bizarre issue with older julia versions where it can be MArray - @test eltype(y) <: StaticVector - @test length(y) == 3 - @test y[1] == [1.0, 0.0, 
0.0] - @test y[2] == [0.0, 1.0, 0.0] - @test y[3] == [0.0, 0.0, 1.0] - - y = onehot(x, 2, 3) - @test eltype(y) <: StaticVector - @test length(y) == 2 - @test y[1] == [0.0, 1.0, 0.0] - @test y[2] == [0.0, 0.0, 1.0] - - x = @SArray [5.0 0.0 6.0] - dx = Enzyme.gradient(Forward, prod, x)[1] - @test dx[1] ≈ 0 - @test dx[2] ≈ 30 - @test dx[3] ≈ 0 - - f0 = x -> sum(2*x) - f1 = x -> @SVector Float64[x[2], 2*x[2]] - f2 = x -> @SMatrix Float64[x[2] x[1]; 2*x[2] 2*x[1]] - - x = @SVector Float64[1, 2] - - dx = gradient(Forward, f0, x)[1] - @test dx isa Enzyme.TupleArray - @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works - @test gradient(Forward, f1, x)[1] isa SMatrix - @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] - @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray - @test Enzyme.jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) - - x = @SMatrix Float64[1 2; 3 4] - - dx = gradient(Forward, f0, x)[1] - @test dx isa Enzyme.TupleArray - @test convert(SArray, dx) == fill(2.0, (2,2)) - @test gradient(Forward, f1, x)[1] isa SArray - @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) - @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray - @test Enzyme.jacobian(Forward, f2, x)[1] == reshape( - Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), - ) - - x = @SVector Float64[1, 2] - - dx = gradient(Reverse, f0, x)[1] - @test dx isa SVector - @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works - @test_broken gradient(Reverse, f1, x)[1] isa SMatrix - @test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0] - @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray - @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) - - x = @SMatrix Float64[1 2; 3 4] - - @test_broken gradient(Reverse, f1, x)[1] isa SArray - @test_broken gradient(Reverse, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) - @test_broken 
Enzyme.jacobian(Reverse, f2, x)[1] isa SArray - @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape( - Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), - ) -end - -function unstable_fun(A0) - A = 'N' in ('H', 'h', 'S', 's') ? wrap(A0) : A0 - (@inbounds A[1])::eltype(A0) -end -@testset "Type unstable static array index" begin - inp = ones(SVector{2, Float64}) - res = Enzyme.gradient(Enzyme.Reverse, unstable_fun, inp)[1] - @test res ≈ [1.0, 0.0] - res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1] - @test res ≈ [1.0, 0.0] -end - -function sparse_eval(x::Vector{Float64}) - A = sparsevec([1, 1, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0]) - B = sparsevec([1, 1, 2, 3], [2.0*x[2], 1.0-x[1], 2.0+x[3], -1.0]) - C = A + B - return A[1] -end - -@testset "Type Unstable SparseArrays" begin - x = [3.1, 2.7, 8.2] - dx = [0.0, 0.0, 0.0] - - autodiff(Reverse, sparse_eval, Duplicated(x, dx)) - - @test x ≈ [3.1, 2.7, 8.2] - @test dx ≈ [-1.0, 43.74, 0] -end - include("sugar.jl") @testset "Forward on Reverse" begin @@ -3745,3 +3630,6 @@ include("ext/logexpfunctions.jl") @testset "BFloat16s ext" begin include("ext/bfloat16s.jl") end + +include("ext/sparsearrays.jl") +include("ext/staticarrays.jl") From f9f8820223267f146751da39076acec654785dcd Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 15 Oct 2024 21:59:24 -0500 Subject: [PATCH 353/495] Newstructt (#1965) * fix * fix --- src/compiler.jl | 277 +-------------------------- src/jlrt.jl | 333 +++++++++++++++++++++++++++++++++ src/rules/typeunstablerules.jl | 156 ++++++++++++++- test/runtests.jl | 1 + test/typeunstable.jl | 104 ++++++++++ 5 files changed, 587 insertions(+), 284 deletions(-) create mode 100644 src/jlrt.jl create mode 100644 test/typeunstable.jl diff --git a/src/compiler.jl b/src/compiler.jl index e58357d8a0..049d96285d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -926,282 +926,7 @@ end using .JIT - -declare_allocobj!(mod) = - get_function!(mod, "julia.gc_alloc_obj") do - T_jlvalue 
= LLVM.StructType(LLVM.LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) - T_size_t = convert(LLVM.LLVMType, Int) - - - LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) - end -function emit_allocobj!( - B, - tag::LLVM.Value, - Size::LLVM.Value, - needs_workaround::Bool, - name::String = "", -) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_pjlvalue = LLVM.PointerType(T_jlvalue) - T_ppjlvalue = LLVM.PointerType(T_pjlvalue) - - T_int8 = LLVM.Int8Type() - T_pint8 = LLVM.PointerType(T_int8) - - pgcstack = reinsert_gcmarker!(fn, B) - ct = inbounds_gep!( - B, - T_pjlvalue, - bitcast!(B, pgcstack, T_ppjlvalue), - [LLVM.ConstantInt(current_task_offset())], - ) - ptls_field = inbounds_gep!(B, T_pjlvalue, ct, [LLVM.ConstantInt(current_ptls_offset())]) - T_ppint8 = LLVM.PointerType(T_pint8) - ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) - - if needs_workaround - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_size_t = convert(LLVM.LLVMType, Int) - # This doesn't allow for optimizations - alty = LLVM.FunctionType(T_prjlvalue, [T_pint8, T_size_t, T_prjlvalue]) - alloc_obj, _ = get_function!(mod, "jl_gc_alloc_typed", alty) - if value_type(Size) != T_size_t # Fix Int32/Int64 issues on 32bit systems - Size = trunc!(B, Size, T_size_t) - end - return call!(B, alty, alloc_obj, [ptls, Size, tag]) - end - - - alloc_obj, alty = declare_allocobj!(mod) - - return call!(B, alty, alloc_obj, [ct, Size, tag], name) -end -function emit_allocobj!(B, T::DataType, name::String = "") - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - - # Obtain tag - tag = unsafe_to_llvm(B, T) - - T_size_t = convert(LLVM.LLVMType, UInt) - Size = 
LLVM.ConstantInt(T_size_t, sizeof(T)) - emit_allocobj!(B, tag, Size, false, name) #=needs_workaround=# -end -declare_pointerfromobjref!(mod) = - get_function!(mod, "julia.pointer_from_objref") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Derived) - T_pjlvalue = LLVM.PointerType(T_jlvalue) - LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) - end -function emit_pointerfromobjref!(B, T) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - func, fty = declare_pointerfromobjref!(mod) - return call!(B, fty, func, [T]) -end - -declare_writebarrier!(mod) = - get_function!(mod, "julia.write_barrier") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg = true) - end -declare_apply_generic!(mod) = - get_function!(mod, "ijl_apply_generic") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType( - T_prjlvalue, - [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()], - ) - end -declare_juliacall!(mod) = - get_function!(mod, "julia.call") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) - end - -function emit_jl!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]) - fn, _ = get_function!(mod, "jl_", FT) - call!(B, FT, fn, [val]) -end - -function emit_getfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - 
T_pprjlvalue = LLVM.PointerType(T_prjlvalue) - T_int32 = LLVM.Int32Type() - - gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) - inv, _ = get_function!(mod, "jl_f_getfield", gen_FT) - - args = [val, fld] - - julia_call, FT = get_function!( - mod, - "julia.call", - LLVM.FunctionType( - T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; - vararg = true, - ), - ) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) - return res -end - - -function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_size_t = convert(LLVM.LLVMType, Int) - - gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_size_t]) - inv, _ = get_function!(mod, "jl_get_nth_field_checked", gen_FT) - - args = [val, fld] - call!(B, gen_FT, inv, args) -end - -function emit_jl_throw!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - T_void = LLVM.VoidType() - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, 12) - FT = LLVM.FunctionType(T_void, [T_prjlvalue]) - fn, _ = get_function!(mod, "jl_throw", FT) - call!(B, FT, fn, [val]) -end - -function emit_box_int32!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_int32 = LLVM.Int32Type() - - FT = LLVM.FunctionType(T_prjlvalue, [T_int32]) - box_int32, _ = get_function!(mod, "ijl_box_int32", FT) - call!(B, FT, box_int32, [val]) -end - -function emit_box_int64!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = 
LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_int64 = LLVM.Int64Type() - - FT = LLVM.FunctionType(T_prjlvalue, [T_int64]) - box_int64, _ = get_function!(mod, "ijl_box_int64", FT) - call!(B, FT, box_int64, [val]) -end - -function emit_apply_generic!(B::LLVM.IRBuilder, args)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_pprjlvalue = LLVM.PointerType(T_prjlvalue) - T_int32 = LLVM.Int32Type() - - gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) - inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) - - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!( - mod, - "julia.call", - LLVM.FunctionType( - T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; - vararg = true, - ), - ) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) - return res -end - -function emit_invoke!(B::LLVM.IRBuilder, args)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_pprjlvalue = LLVM.PointerType(T_prjlvalue) - T_int32 = LLVM.Int32Type() - - # {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* @ijl_invoke - gen_FT = - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) - inv = get_function!(mod, "ijl_invoke", gen_FT) - - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!( - mod, - "julia.call2", - LLVM.FunctionType( - T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; - vararg = true, - ), - ) - res = call!(B, FT, julia_call, [inv, args...]) - return res -end - -function emit_svec!(B, args)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - fn, fty = get_function!(mod, "jl_svec") - sz = convert(LLVMType, Csize_t) - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [sz]; vararg = true) - - sz = convert(LLVMType, Csize_t) - call!(B, fty, fn, [LLVM.ConstantInt(sz, length(args)), args...]) -end +include("jlrt.jl") AnyArray(Length::Int) = NamedTuple{ntuple(i -> Symbol(i), Val(Length)),NTuple{Length,Any}} diff --git a/src/jlrt.jl b/src/jlrt.jl new file mode 100644 index 0000000000..70b6fb2ad3 --- /dev/null +++ b/src/jlrt.jl @@ -0,0 +1,333 @@ +# For julia runtime function emission + +declare_allocobj!(mod::LLVM.Module) = + get_function!(mod, "julia.gc_alloc_obj") do + T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) + T_size_t = convert(LLVM.LLVMType, Int) + + + LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) + end +function emit_allocobj!( + B::LLVM.IRBuilder, + @nospecialize(tag::LLVM.Value), + @nospecialize(Size::LLVM.Value), + needs_workaround::Bool, + name::String = "", +) + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_pjlvalue = LLVM.PointerType(T_jlvalue) + T_ppjlvalue = LLVM.PointerType(T_pjlvalue) + + T_int8 = LLVM.Int8Type() + T_pint8 = LLVM.PointerType(T_int8) + + pgcstack = reinsert_gcmarker!(fn, B) + ct 
= inbounds_gep!( + B, + T_pjlvalue, + bitcast!(B, pgcstack, T_ppjlvalue), + [LLVM.ConstantInt(current_task_offset())], + ) + ptls_field = inbounds_gep!(B, T_pjlvalue, ct, [LLVM.ConstantInt(current_ptls_offset())]) + T_ppint8 = LLVM.PointerType(T_pint8) + ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) + + if needs_workaround + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_size_t = convert(LLVM.LLVMType, Int) + # This doesn't allow for optimizations + alty = LLVM.FunctionType(T_prjlvalue, [T_pint8, T_size_t, T_prjlvalue]) + alloc_obj, _ = get_function!(mod, "jl_gc_alloc_typed", alty) + if value_type(Size) != T_size_t # Fix Int32/Int64 issues on 32bit systems + Size = trunc!(B, Size, T_size_t) + end + return call!(B, alty, alloc_obj, [ptls, Size, tag]) + end + + + alloc_obj, alty = declare_allocobj!(mod) + + return call!(B, alty, alloc_obj, [ct, Size, tag], name) +end +function emit_allocobj!(B::LLVM.IRBuilder, @nospecialize(T::DataType), name::String = "") + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + # Obtain tag + tag = unsafe_to_llvm(B, T) + + T_size_t = convert(LLVM.LLVMType, UInt) + Size = LLVM.ConstantInt(T_size_t, sizeof(T)) + emit_allocobj!(B, tag, Size, false, name) #=needs_workaround=# +end + +declare_pointerfromobjref!(mod::LLVM.Module) = + get_function!(mod, "julia.pointer_from_objref") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Derived) + T_pjlvalue = LLVM.PointerType(T_jlvalue) + LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) + end + +function emit_pointerfromobjref!(B::LLVM.IRBuilder, @nospecialize(T::LLVM.Value)) + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + func, fty = declare_pointerfromobjref!(mod) + return call!(B, fty, func, [T]) +end + +declare_writebarrier!(mod) = + 
get_function!(mod, "julia.write_barrier") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg = true) + end +declare_apply_generic!(mod::LLVM.Module) = + get_function!(mod, "ijl_apply_generic") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType( + T_prjlvalue, + [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()], + ) + end +declare_juliacall!(mod::LLVM.Module) = + get_function!(mod, "julia.call") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) + end + +function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]) + fn, _ = get_function!(mod, "jl_", FT) + call!(B, FT, fn, [val]) +end + +function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_pprjlvalue = LLVM.PointerType(T_prjlvalue) + T_int32 = LLVM.Int32Type() + + gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) + inv, _ = get_function!(mod, "jl_f_getfield", gen_FT) + + args = [val, fld] + + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; + vararg = true, + ), + ) + res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) + return res +end + + +function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, 
@nospecialize(fld::LLVM.Value))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_size_t = convert(LLVM.LLVMType, Int) + + gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_size_t]) + inv, _ = get_function!(mod, "jl_get_nth_field_checked", gen_FT) + + args = [val, fld] + call!(B, gen_FT, inv, args) +end + +function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), fld::Integer)::LLVM.Value + emit_nthfield!(B, val, LLVM.ConstantInt(Int(fld))) +end + +function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + T_void = LLVM.VoidType() + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, 12) + FT = LLVM.FunctionType(T_void, [T_prjlvalue]) + fn, _ = get_function!(mod, "jl_throw", FT) + call!(B, FT, fn, [val]) +end + +function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_int32 = LLVM.Int32Type() + + FT = LLVM.FunctionType(T_prjlvalue, [T_int32]) + box_int32, _ = get_function!(mod, "ijl_box_int32", FT) + call!(B, FT, box_int32, [val]) +end + +function emit_box_int64!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_int64 = LLVM.Int64Type() + + FT = LLVM.FunctionType(T_prjlvalue, [T_int64]) + box_int64, _ = get_function!(mod, "ijl_box_int64", FT) + call!(B, FT, box_int64, [val]) +end + +function emit_apply_generic!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value + 
curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_pprjlvalue = LLVM.PointerType(T_prjlvalue) + T_int32 = LLVM.Int32Type() + + gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) + inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) + + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; + vararg = true, + ), + ) + res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) + return res +end + +function emit_invoke!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_pprjlvalue = LLVM.PointerType(T_prjlvalue) + T_int32 = LLVM.Int32Type() + + # {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* @ijl_invoke + gen_FT = + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) + inv = get_function!(mod, "ijl_invoke", gen_FT) + + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!( + mod, + "julia.call2", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) + res = call!(B, FT, julia_call, [inv, args...]) + return res +end + +function emit_svec!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + fn, fty = get_function!(mod, "jl_svec") + sz = convert(LLVMType, Csize_t) + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType(T_prjlvalue, [sz]; vararg = true) + + sz = convert(LLVMType, Csize_t) + call!(B, fty, fn, [LLVM.ConstantInt(sz, length(args)), args...]) +end + + +function val_from_byref_if_mixed(B::LLVM.IRBuilder, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) + legal, TT, _ = abs_typeof(oval) + @assert legal + world = enzyme_extract_world(LLVM.parent(position(B))) + act = active_reg_inner(TT, (), world) + if act == ActiveState || act == MixedState + legal2, TT2, _ = abs_typeof(val) + if legal2 + @assert TT2 <: Base.RefValue + else + shadowpointer = false + if isa(val, LLVM.PHIInst) + if size(incoming(val))[1] == 0 + shadowpointer = true + end + elseif isa(val, LLVM.ExtractValueInst) + m = operands(val)[1] + if isa(m, LLVM.PHIInst) + if size(incoming(m))[1] == 0 + shadowpointer = true + end + end + end + @assert shadowpointer + end + return emit_nthfield!(B, val, 0) + else + return val + end +end + +function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value)) + legal, TT, _ = abs_typeof(val) + @assert legal + world = enzyme_extract_world(LLVM.parent(position(B))) + act = active_reg_inner(TT, (), world) + + if act == ActiveState || act == MixedState + obj = emit_allocobj!(B, Base.RefValue{TT}) + lty = 
convert(LLVMType, TT) + ld = load!(B, lty, bitcast!(B, val, LLVM.PointerType(lty, addrspace(value_type(val))))) + store!(B, ld, bitcast!(B, obj, LLVM.PointerType(lty, addrspace(value_type(val))))) + emit_writebarrier!(B, get_julia_inner_types(B, obj, ld)) + return obj + else + return val + end +end \ No newline at end of file diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 500ff53d9a..6371e409b1 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -775,9 +775,21 @@ end end @register_fwd function new_structt_fwd(B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end + origops = collect(operands(orig)) width = get_width(gutils) @@ -804,7 +816,7 @@ end new_from_original(gutils, origops[1]), extract_value!(B, shadowsin, idx - 1), ] - tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) + tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) callconv!(tmp, callconv(orig)) shadowres = insert_value!(B, shadowres, tmp, idx - 1) end @@ -814,7 +826,113 @@ end end @register_aug function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool - new_structt_fwd(B, orig, gutils, normalR, shadowR) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) + return true + end + + origops = collect(operands(orig)) + width = get_width(gutils) + + @assert is_constant_value(gutils, origops[1]) 
+ if is_constant_value(gutils, origops[2]) + emit_error( + B, + orig, + "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t" * + string(orig), + ) + end + + shadowsin = invert_pointer(gutils, origops[2], B) + if width == 1 + vals = [new_from_original(gutils, origops[1]), val_from_byref_if_mixed(B, origops[2], shadowsin)] + shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) + callconv!(shadowres, callconv(orig)) + shadowres = byref_from_val_if_mixed(B, shadowres) + else + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + vals = [ + new_from_original(gutils, origops[1]), + val_from_byref_if_mixed(B, origops[2], extract_value!(B, shadowsin, idx - 1)), + ] + tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) + callconv!(tmp, callconv(orig)) + tmp = byref_from_val_if_mixed(B, tmp) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) + end + end + unsafe_store!(shadowR, shadowres.ref) + + legal, TT, _ = abs_typeof(orig) + @assert legal + world = enzyme_extract_world(LLVM.parent(position(B))) + act = active_reg_inner(TT, (), world) + if act == ActiveState || act == MixedState + unsafe_store!(tapeR, shadowres.ref) + end + + return false +end + +@generated function recursive_tuple(::Val{num}, lhs, rhs) where num + exprs = Expr[] + for i in 1:num + push!(exprs, quote + recursive_add(lhs[$i], getfield(rhs, $i), identity, guaranteed_nonactive) + end) + end + return quote + Base.@_inline_meta + ($(exprs...),) + end +end + +@generated function runtime_newstructt_rev(::Val{Width}, revres0::RR0, revarg0::RA0, args::Vararg{Any, N}) where {Width, RR0, RA0, N} + exprs = Expr[] + for i in 1:Width + dres = if i == 1 + :revres0 + else + :(args[$(2*(i-2)+1)]) + end + darg = if i == 1 + :revarg0 + else + :(args[$(2*(i-2)+1+1)]) + end + push!(exprs, quote + @assert $dres isa Base.RefValue + if $darg isa Base.RefValue + tmparg = $darg[] + tmpres = $dres[] + $darg[] = 
recursive_tuple(Val(length(tmparg)), tmparg, tmpres) + else + error( + "Enzyme Mutability Error: Cannot accumulate in place to immutable value " * + string($darg), + ) + end + end) + end + expr = quote + Base.@_inline_meta + $(exprs...) + return nothing + end + return expr end @register_rev function new_structt_rev(B, orig, gutils, tape) @@ -836,11 +954,33 @@ end if !needsShadow return end - emit_error( - B, - orig, - "Enzyme: Not yet implemented reverse for jl_new_structt " * string(orig), - ) + + origops = collect(operands(orig)) + width = get_width(gutils) + + legal, TT, _ = abs_typeof(orig) + @assert legal + world = enzyme_extract_world(LLVM.parent(position(B))) + act = active_reg_inner(TT, (), world) + if act == ActiveState || act == MixedState + vals = LLVM.Value[ + unsafe_to_llvm(B, runtime_newstructt_rev), + unsafe_to_llvm(B, Val(Int(width))), + ] + + shadowsin = lookup_value(gutils, invert_pointer(gutils, origops[2], B), B) + if width == 1 + push!(vals, tape) + push!(vals, shadowsin) + else + for i in 1:width + push!(vals, extract_value!(B, tape, i - 1)) + push!(vals, extract_value!(B, shadowsin, i - 1)) + end + end + emit_apply_generic!(B, vals) + end + return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index f421f35625..8c7ca39abc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,6 +80,7 @@ include("kwrules.jl") include("kwrrules.jl") include("internal_rules.jl") include("ruleinvalidation.jl") +include("typeunstable.jl") @static if !Sys.iswindows() include("blas.jl") diff --git a/test/typeunstable.jl b/test/typeunstable.jl new file mode 100644 index 0000000000..b3600413a1 --- /dev/null +++ b/test/typeunstable.jl @@ -0,0 +1,104 @@ +using Enzyme, Test + +@eval construct_splatnew(T, fields) = $(Expr(:splatnew, :T, :fields)) + +struct ActivePair + x::Float32 + y::Float64 +end + +function toactivepair(x, y) + tup = Base.inferencebarrier((x, y)) + pair = construct_splatnew(ActivePair, tup) + (pair.x * pair.y)::Float64 +end + +struct 
VectorPair + x::Vector{Float32} + y::Vector{Float64} +end + + +function tovectorpair(x, y) + tup = Base.inferencebarrier((x, y)) + pair = construct_splatnew(VectorPair, tup) + (pair.x[1] * pair.y[1])::Float64 +end + + +function toactivepair!(res, x, y) + tup = Base.inferencebarrier((x[1], y[1])) + pair = construct_splatnew(ActivePair, tup) + res[] = (pair.x * pair.y)::Float64 + nothing +end + +function tovectorpair!(res, x, y) + tup = Base.inferencebarrier((x, y)) + pair = construct_splatnew(VectorPair, tup) + res[] = (pair.x[1] * pair.y[1])::Float64 + nothing +end + +@testset "Reverse Unstable newstructt" begin + res = Enzyme.autodiff(Reverse, toactivepair, Active(2.7f0), Active(3.1)) + @test res[1][1] ≈ 3.1f0 + @test res[1][2] ≈ 2.700000047683716 + + x = Float32[2.7f0] + dx = Float32[0.0f0] + y = Float64[3.1] + dy = Float64[0.0] + + Enzyme.autodiff(Reverse, tovectorpair, Duplicated(x, dx), Duplicated(y, dy)) + @test dx[1] ≈ 3.1f0 + @test dy[1] ≈ 2.700000047683716 + + x = Float32[2.7f0] + dx = Float32[0.0f0] + dx2 = Float32[0.0f0] + y = Float64[3.1] + dy = Float64[0.0] + dy2 = Float64[0.0] + + res = Ref(0.0) + dres = Ref(1.0) + dres2 = Ref(3.0) + + Enzyme.autodiff(Reverse, toactivepair!, BatchDuplicated(res, (dres, dres2)), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + + @test dx[1] ≈ 3.1f0 + @test dy[1] ≈ 2.700000047683716 + + @test dx2[1] ≈ 3.1f0 * 3 + @test dy2[1] ≈ 2.700000047683716 * 3 + + + x = Float32[2.7f0] + dx = Float32[0.0f0] + dx2 = Float32[0.0f0] + y = Float64[3.1] + dy = Float64[0.0] + dy2 = Float64[0.0] + + res = Ref(0.0) + dres = Ref(1.0) + dres2 = Ref(3.0) + + Enzyme.autodiff(Reverse, tovectorpair!, BatchDuplicated(res, (dres, dres2)), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + + @test dx[1] ≈ 3.1f0 + @test dy[1] ≈ 2.700000047683716 + + @test dx2[1] ≈ 3.1f0 * 3 + @test dy2[1] ≈ 2.700000047683716 * 3 + +end + +@testset "Forward Unstable newstructt" begin + res = Enzyme.autodiff(Forward, toactivepair, 
Duplicated(2.7f0, 2.0f0), Duplicated(3.1, 3.0)) + @test res[1] ≈ 2.7f0 * 3.0 + 2.0f0 * 3.1 + res = Enzyme.autodiff(Forward, toactivepair, BatchDuplicated(2.7f0, (2.0f0, 5.0f0)), BatchDuplicated(3.1, (3.0, 7.0))) + @test res[1][1] ≈ 2.7f0 * 3.0 + 2.0f0 * 3.1 + @test res[1][2] ≈ 2.7f0 * 7.0 + 5.0f0 * 3.1 +end \ No newline at end of file From 42ecd12cf5076f8d3db1694e014f69bc0b99173f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 15 Oct 2024 22:40:42 -0500 Subject: [PATCH 354/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 982a7c338f..acbcd8a079 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.9" +version = "0.13.10" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 4b8c3528cf423931a56b3f15cb382a5d662914b3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 01:35:16 -0500 Subject: [PATCH 355/495] CompatHelper: bump compat for GPUCompiler to 1, (keep existing compat) (#1960) Co-authored-by: CompatHelper Julia --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index acbcd8a079..ad6ee87529 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" Enzyme_jll = "0.0.153" -GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" +GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" From 5172d77c4c7526b9a0d8ba8f901d30871983e137 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 16 Oct 2024 08:41:26 +0200 Subject: [PATCH 356/495] Fix Codecov badge (#1967) * Fix Codecov badge * Replace master with main --- 
README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 67287c181a..d0466974a3 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@ [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://enzyme.mit.edu/julia/stable) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://enzyme.mit.edu/julia/dev) -[![Build Status](https://github.com/wsmoses/Enzyme.jl/workflows/CI/badge.svg)](https://github.com/wsmoses/Enzyme.jl/actions) -[![Coverage](https://codecov.io/gh/wsmoses/Enzyme.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/wsmoses/Enzyme.jl) +[![Build Status](https://github.com/EnzymeAD/Enzyme.jl/workflows/CI/badge.svg)](https://github.com/EnzymeAD/Enzyme.jl/actions) +[![Coverage](https://codecov.io/gh/EnzymeAD/Enzyme.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/EnzymeAD/Enzyme.jl) -This is a package containing the Julia bindings for [Enzyme](https://github.com/wsmoses/enzyme). This is very much a work in progress and bug reports/discussion is greatly appreciated! +This is a package containing the Julia bindings for [Enzyme](https://github.com/EnzymeAD/enzyme). This is very much a work in progress and bug reports/discussion is greatly appreciated! Enzyme is a plugin that performs automatic differentiation (AD) of statically analyzable LLVM. It is highly-efficient and its ability perform AD on optimized code allows Enzyme to meet or exceed the performance of state-of-the-art AD tools. 
From b80735e9437180b1c5952582fcca06acb62a15d4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 16 Oct 2024 01:42:20 -0500 Subject: [PATCH 357/495] Julia 1.11: the adventure continues (#1966) * Julia 1.11: the adventure continues * more fixups of gc * fixup * fixup * mem * fix * evoice * fix * improve del * fix * assertion * fix * around * fix * fix * fix --- src/compiler.jl | 9 ++- src/compiler/optimize.jl | 54 +++++++++++-- src/compiler/utils.jl | 5 ++ src/compiler/validation.jl | 150 ++++++++++++++++++++++++++++++++----- 4 files changed, 192 insertions(+), 26 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 049d96285d..710725dcca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2965,8 +2965,13 @@ function zero_allocation( push!(function_attributes(wrapper_f), StringAttribute("enzyme_no_escaping_allocation")) push!(function_attributes(wrapper_f), EnumAttribute("alwaysinline", 0)) push!(function_attributes(wrapper_f), EnumAttribute("nofree", 0)) - push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0)) - push!(function_attributes(wrapper_f), EnumAttribute("writeonly", 0)) + + if LLVM.version().major <= 15 + push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0)) + push!(function_attributes(wrapper_f), EnumAttribute("writeonly", 0)) + else + push!(function_attributes(wrapper_f), EnumAttribute("memory", WriteOnlyArgMemEffects.data)) + end push!(function_attributes(wrapper_f), EnumAttribute("willreturn", 0)) if LLVM.version().major >= 12 push!(function_attributes(wrapper_f), EnumAttribute("mustprogress", 0)) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index eccc5789dd..7214aa540f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -793,7 +793,20 @@ function nodecayed_phis!(mod::LLVM.Module) end if addr == 13 && !hasload if isa(v, LLVM.LoadInst) - return getparent(operands(v)[1], offset, true) + v2, o2, hl2 = getparent(operands(v)[1], LLVM.ConstantInt(offty, 0), true) + 
@assert o2 == LLVM.ConstantInt(offty, 0) + return v2, offset, true + end + if isa(v, LLVM.CallInst) + cf = LLVM.called_operand(v) + if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" + ld = operands(v)[2] + if isa(ld, LLVM.LoadInst) + v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true) + @assert o2 == LLVM.ConstantInt(offty, sizeof(Int)) + return v2, offset, true + end + end end end @@ -894,7 +907,7 @@ function nodecayed_phis!(mod::LLVM.Module) return v2, offset, skipload end - if isa(v, LLVM.GetElementPtrInst) && !hasload + if isa(v, LLVM.GetElementPtrInst) v2, offset, skipload = getparent(operands(v)[1], offset, hasload) offset = nuwadd!( @@ -1035,9 +1048,40 @@ function nodecayed_phis!(mod::LLVM.Module) position!(nb, nonphi) if addr == 13 - nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) - nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) - nphi = load!(nb, ty, nphi) + @static if VERSION < v"1.11-" + nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) + nphi = load!(nb, ty, nphi) + else + base_obj = nphi + + # %value_phi11 = phi {} addrspace(10)* [ %55, %L78 ], [ %54, %L76 ] + + # %.phi.trans.insert77 = bitcast {} addrspace(10)* %value_phi11 to { i64, {} addrspace(10)** } addrspace(10)* + # %.phi.trans.insert78 = addrspacecast { i64, {} addrspace(10)** } addrspace(10)* %.phi.trans.insert77 to { i64, {} addrspace(10)** } addrspace(11)* + # %.phi.trans.insert79 = getelementptr inbounds { i64, {} addrspace(10)** }, { i64, {} addrspace(10)** } addrspace(11)* %.phi.trans.insert78, i64 0, i32 1 + # %.pre80 = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %.phi.trans.insert79, align 8, !dbg !532, !tbaa !19, !alias.scope !26, !noalias !29 + + # %154 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %value_phi11, {} addrspace(10)** %.pre80), !dbg !532 + + jlt = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10) + pjlt = 
LLVM.PointerType(jlt) + gent = LLVM.StructType([convert(LLVMType, Int), pjlt]) + pgent = LLVM.PointerType(LLVM.StructType([convert(LLVMType, Int), pjlt]), 10) + + nphi = bitcast!(nb, nphi, pgent) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(gent, 11)) + nphi = inbounds_gep!(nb, gent, nphi, [LLVM.ConstantInt(Int64(0)), LLVM.ConstantInt(Int32(1))]) + nphi = load!(nb, pjlt, nphi) + + GTy = LLVM.FunctionType(LLVM.PointerType(jlt, 13), LLVM.LLVMType[jlt, pjlt]) + gcloaded, _ = get_function!( + mod, + "julia.gc_loaded", + GTy + ) + nphi = call!(nb, GTy, gcloaded, LLVM.Value[base_obj, nphi]) + end else nphi = addrspacecast!(nb, nphi, ty) end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 8a801067eb..5539c5ed06 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -35,6 +35,11 @@ const ReadOnlyArgMemEffects = MemoryEffect( (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other)), ) +const WriteOnlyArgMemEffects = MemoryEffect( + (MRI_Mod << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), +) const NoEffects = MemoryEffect( (MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index b672c50f57..839aa120d7 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -373,38 +373,150 @@ function check_ir!(job, errors, mod::LLVM.Module) eraseInst(mod, f) end rewrite_ccalls!(mod) + + del = LLVM.Function[] for f in collect(functions(mod)) - check_ir!(job, errors, imported, f) + if in(f, del) + continue + end + check_ir!(job, errors, imported, f, del) + end + for d in del + LLVM.API.LLVMDeleteFunction(d) end + + del = LLVM.Function[] for f in collect(functions(mod)) - check_ir!(job, errors, imported, f) + if in(f, del) + continue + end + check_ir!(job, errors, imported, f, del) + end + for d in del + 
LLVM.API.LLVMDeleteFunction(d) end return errors end -function check_ir!(job, errors, imported, f::LLVM.Function) + +function unwrap_ptr_casts(val::LLVM.Value) + while true + is_simple_cast = false + is_simple_cast |= isa(val, LLVM.BitCastInst) + is_simple_cast |= isa(val, LLVM.AddrSpaceCastInst) || isa(val, LLVM.PtrToIntInst) + is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMAddrSpaceCast + is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMIntToPtr + is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMBitCast + + if !is_simple_cast + return val + else + val = operands(val)[1] + end + end +end + +function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) calls = [] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 - for bb in blocks(f), inst in instructions(bb) + mod = LLVM.parent(f) + for bb in blocks(f), inst in collect(instructions(bb)) if isa(inst, LLVM.CallInst) push!(calls, inst) # remove illegal invariant.load and jtbaa_const invariants - elseif isInline && isa(inst, LLVM.LoadInst) - md = metadata(inst) - if haskey(md, LLVM.MD_tbaa) - modified = LLVM.Metadata( - ccall( - (:EnzymeMakeNonConstTBAA, API.libEnzyme), - LLVM.API.LLVMMetadataRef, - (LLVM.API.LLVMMetadataRef,), - md[LLVM.MD_tbaa], - ), - ) - setindex!(md, modified, LLVM.MD_tbaa) - end - if haskey(md, LLVM.MD_invariant_load) - delete!(md, LLVM.MD_invariant_load) + elseif isa(inst, LLVM.LoadInst) + + fn_got = unwrap_ptr_casts(operands(inst)[1]) + fname = String(name(fn_got)) + match_ = match(r"^jlplt_(.*)_\d+_got$", fname) + + if match_ !== nothing + fname = match_[1] + FT = nothing + todo = LLVM.Instruction[inst] + while length(todo) != 0 + v = pop!(todo) + for u in LLVM.uses(v) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + FT = called_type(u) + break + end + if isa(u, LLVM.BitCastInst) + push!(todo, u) + continue + end + end + if FT !== nothing + break + end + end + @assert FT 
!== nothing + newf, _ = get_function!(mod, String(fname), FT) + + initfn = unwrap_ptr_casts(LLVM.initializer(fn_got)) + loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst + opv = operands(loadfn)[1]::LLVM.GlobalVariable + + if startswith(fname, "jl_") || startswith(fname, "ijl_") + else + @assert "unsupported jl got" + msg = sprint() do io::IO + println( + io, + "Enzyme internal error unsupported got", + ) + println(io, "inst=", inst) + println(io, "fname=", fname) + println(io, "FT=", FT) + println(io, "fn_got=", fn_got) + println(io, "init=", string(initfn)) + println(io, "opv=", string(opv)) + end + throw(AssertionError(msg)) + end + + if value_type(newf) != value_type(inst) + newf = const_pointercast(newf, value_type(inst)) + end + replace_uses!(inst, newf) + LLVM.API.LLVMInstructionEraseFromParent(inst) + + baduse = false + for u in LLVM.uses(fn_got) + u = LLVM.user(u) + if isa(u, LLVM.StoreInst) + continue + end + baduse = true + end + + if !baduse + push!(deletedfns, initfn) + LLVM.initializer!(fn_got, LLVM.null(value_type(LLVM.initializer(fn_got)))) + replace_uses!(opv, LLVM.null(value_type(opv))) + LLVM.API.LLVMDeleteGlobal(opv) + replace_uses!(fn_got, LLVM.null(value_type(fn_got))) + LLVM.API.LLVMDeleteGlobal(fn_got) + end + + elseif isInline + md = metadata(inst) + if haskey(md, LLVM.MD_tbaa) + modified = LLVM.Metadata( + ccall( + (:EnzymeMakeNonConstTBAA, API.libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef,), + md[LLVM.MD_tbaa], + ), + ) + setindex!(md, modified, LLVM.MD_tbaa) + end + if haskey(md, LLVM.MD_invariant_load) + delete!(md, LLVM.MD_invariant_load) + end end end end From 68d9d31482aafa033d16ab743882dd786c9b62b7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 16 Oct 2024 18:38:20 -0500 Subject: [PATCH 358/495] Improve shadow error message (#1971) * Improve shadow error message * Update compiler.jl --- src/compiler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl 
b/src/compiler.jl index 710725dcca..f7811ba947 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2798,7 +2798,9 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie fn = LLVM.parent(LLVM.parent(V)) world = enzyme_extract_world(fn) has, Ty, byref = abs_typeof(V) - @assert has + if !has + throw(AssertionError("Allocation could not have its type statically determined $(string(V))")) + end rt = active_reg_inner(Ty, (), world) if rt == ActiveState || rt == MixedState B = LLVM.IRBuilder() From 93d690a30c9fd18a71ef90aeee4e676f2ed7270e Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 16 Oct 2024 18:39:51 -0500 Subject: [PATCH 359/495] Improve newstructt err (#1973) --- src/jlrt.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/jlrt.jl b/src/jlrt.jl index 70b6fb2ad3..8338986e19 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -285,7 +285,9 @@ end function val_from_byref_if_mixed(B::LLVM.IRBuilder, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) legal, TT, _ = abs_typeof(oval) - @assert legal + if !legal + throw(AssertionError("Could not determine type of value within jl_newstructt arg: $(string(oval))")) + end world = enzyme_extract_world(LLVM.parent(position(B))) act = active_reg_inner(TT, (), world) if act == ActiveState || act == MixedState @@ -330,4 +332,4 @@ function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Valu else return val end -end \ No newline at end of file +end From e2b0e41ea770a7a0a7e9a8566b975e48d340f45e Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 16 Oct 2024 18:44:25 -0500 Subject: [PATCH 360/495] Revert "Fix code coverage & update action versions (#1954)" (#1974) This reverts commit c0c5e5169987612641383de68b67d6dadb334f1a. 
--- .github/workflows/CI.yml | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e5c92993e1..60d713c529 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -54,13 +54,13 @@ jobs: version: '1.11' assertions: true steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 if: ${{ ! matrix.assertions }} with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/checkout@v4 + - uses: actions/checkout@v3 if: ${{ matrix.assertions }} with: repository: 'JuliaLang/julia' @@ -72,7 +72,7 @@ jobs: sed -i.bak 's/exit 2/exit 0/g' julia/deps/tools/jlchecksum make -C julia -j $(nproc) FORCE_ASSERTIONS=1 LLVM_ASSERTIONS=1 JULIA_PRECOMPILE=0 echo $PWD/julia/usr/bin >> $GITHUB_PATH - - uses: actions/cache@v4 + - uses: actions/cache@v1 env: cache-name: cache-artifacts with: @@ -120,12 +120,10 @@ jobs: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - uses: julia-actions/julia-processcoverage@v1 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' - - uses: codecov/codecov-action@v4 + - uses: codecov/codecov-action@v1 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: file: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false # or true if you want CI to fail when Codecov fails enzymetestutils: name: EnzymeTestUtils - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libEnzyme }} libEnzyme - ${{ github.event_name }} runs-on: ${{ matrix.os }} @@ -145,12 +143,12 @@ jobs: - x64 libEnzyme: [packaged] steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v4 + - uses: actions/cache@v1 env: 
cache-name: cache-artifacts with: @@ -182,12 +180,10 @@ jobs: if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: directories: lib/EnzymeTestUtils/src - - uses: codecov/codecov-action@v4 + - uses: codecov/codecov-action@v2 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false # or true if you want CI to fail when Codecov fails integration: name: Integration Tests - ${{ matrix.test }} runs-on: ${{ matrix.os }} @@ -204,10 +200,10 @@ jobs: - DynamicExpressions steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 + - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: julia-actions/cache@v4 + - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" run: | @@ -218,11 +214,11 @@ jobs: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 with: version: '1' - - uses: julia-actions/cache@v4 + - uses: julia-actions/cache@v1 - run: | julia --project=docs -e ' using Pkg From 1e4dbec7e550263bd3d2ebe5ee9d9145299d4319 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 16 Oct 2024 21:39:17 -0400 Subject: [PATCH 361/495] fix `Vararg` (#1969) `Vararg{<:T}` is deprecated so this was causing LinearSolve.jl test failures. 
--- src/Enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7caa06c281..0a3c030f11 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1866,7 +1866,7 @@ function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} return a.data[start] end -@inline function tupstack(data::Tuple{Vararg{<:Array{T}}}, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) where {T} +@inline function tupstack(data::Tuple{Vararg{Array{T}}}, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) where {T} num = prod(outshape) res = Array{T}(undef, outshape..., inshape...) for (i, val) in enumerate(data) From 72ae040d8dc83b791ca396ac684a154353d6f5f6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 16 Oct 2024 21:30:09 -0500 Subject: [PATCH 362/495] Fix const return for fwd (#1975) * Fix const return for fwd * Improve error message * cleanup --- src/compiler.jl | 14 ++++++++------ src/utils.jl | 4 +++- test/abi.jl | 9 +++++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f7811ba947..d006cc9b55 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5155,11 +5155,13 @@ function create_abi_wrapper( count_Sret = 0 count_llvm_Sret = 0 if !isghostty(actualRetType) - if returnPrimal - count_llvm_Sret += 1 - end - if !(rettype <: Const) - count_llvm_Sret += 1 + if !Core.Compiler.isconstType(actualRetType) + if returnPrimal + count_llvm_Sret += 1 + end + if !(rettype <: Const) + count_llvm_Sret += 1 + end end end if !isghostty(literal_rt) @@ -5174,7 +5176,7 @@ function create_abi_wrapper( eval = fixup_abi( returnNum + 1, if count_llvm_Sret == 0 - makeInstanceOf(builder, sret_types[returnNum+1]) + makeInstanceOf(builder, actualRetType) elseif count_llvm_Sret == 1 val else diff --git a/src/utils.jl b/src/utils.jl index d042859b89..cad25cf277 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -114,7 +114,9 @@ end export unsafe_to_llvm, unsafe_nothing_to_llvm function 
makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T)) - @assert Core.Compiler.isconstType(T) + if !Core.Compiler.isconstType(T) + throw(AssertionError("Tried to make instance of non constant type $T")) + end @assert T <: Type return unsafe_to_llvm(B, T.parameters[1]) end diff --git a/test/abi.jl b/test/abi.jl index 7a7917553f..acc8f26090 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -1,6 +1,15 @@ using Enzyme using Test +retty() = Float64 + +@testset "Const Return" begin + res = Enzyme.autodiff(ForwardWithPrimal, retty, Const) + @test res === NamedTuple{(Symbol("1"),), Tuple{Type{Float64}}}((Float64,)) + res = Enzyme.autodiff(Forward, retty, Const) + @test res === () +end + @testset "ABI & Calling convention" begin f(x) = x From bc64880d135576c3cb34b82c218bf7a288478bdf Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 16 Oct 2024 22:51:44 -0500 Subject: [PATCH 363/495] Initial memory handling (#1972) * Initial memory handling * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --- src/absint.jl | 15 +- src/compiler.jl | 526 +++---------------------------- src/jlrt.jl | 664 ++++++++++++++++++++++++++++++++++++++++ src/rules/allocrules.jl | 81 ++--- src/rules/llvmrules.jl | 345 +++++++++++++++------ src/rules/typerules.jl | 40 +++ 6 files changed, 1041 insertions(+), 630 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 77ce2b6a7e..5519b0b862 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -292,7 +292,9 @@ function abs_typeof( nm == "jl_alloc_array_3d" || nm == "ijl_alloc_array_3d" || nm == "jl_new_array" || - nm == "ijl_new_array" + nm == "ijl_new_array" || + nm == "jl_alloc_genericmemory" || + nm == "ijl_alloc_genericmemory" vals = absint(operands(arg)[1], partial) return (vals[1], vals[2], vals[1] ? 
GPUCompiler.MUT_REF : nothing) end @@ -365,6 +367,17 @@ function abs_typeof( end return (legal, RT, nothing) end + @static if VERSION < v"1.11-" + else + if nm == "jl_genericmemory_copy_slice" || nm == "ijl_genericmemory_copy_slice" + legal, RT, _ = abs_typeof(operands(arg)[1], partial) + if legal + @assert RT <: Memory + return (legal, RT, GPUCompiler.MUT_REF) + end + return (legal, RT, nothing) + end + end _, RT = enzyme_custom_extract_mi(arg, false) if RT !== nothing diff --git a/src/compiler.jl b/src/compiler.jl index d006cc9b55..8f8a01f713 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -197,6 +197,8 @@ const nofreefns = Set{String}(( "jl_array_ptr_copy", "ijl_array_copy", "jl_array_copy", + "ijl_genericmemory_copy_slice", + "jl_genericmemory_copy_slice", "ijl_get_nth_field_checked", "ijl_get_nth_field_checked", "jl_array_del_end", @@ -1096,367 +1098,6 @@ const JuliaGlobalNameMap = Dict{String,Any}( include("absint.jl") -function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - legal = true - found = [] - for arg in args - slegal, foundv = absint(arg) - if slegal - push!(found, foundv) - else - legal = false - break - end - end - - if legal - return unsafe_to_llvm(B, Ty{found...}) - end - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_pprjlvalue = LLVM.PointerType(T_prjlvalue) - T_int32 = LLVM.Int32Type() - - generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) - f_apply_type, _ = get_function!(mod, "jl_f_apply_type", generic_FT) - Ty = unsafe_to_llvm(B, Ty) - - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!( - mod, - "julia.call", - LLVM.FunctionType( - T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; - vararg = true, - ), - ) - tag = call!( - B, - FT, - julia_call, - LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...], - ) - return tag -end - -function emit_tuple!(B, args)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - legal = true - found = [] - for arg in args - slegal, foundv = absint(arg) - if slegal - push!(found, foundv) - else - legal = false - break - end - end - - if legal - return unsafe_to_llvm(B, (found...,)) - end - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_pprjlvalue = LLVM.PointerType(T_prjlvalue) - T_int32 = LLVM.Int32Type() - - generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) - f_apply_type, _ = get_function!(mod, "jl_f_tuple", generic_FT) - - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!( - mod, - "julia.call", - LLVM.FunctionType( - T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; - vararg = true, - ), - ) - tag = call!( - B, - FT, - julia_call, - LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...], - ) - return tag -end - -function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - legal, val, byref = abs_typeof(arg) - if legal - return unsafe_to_llvm(B, val) - end - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) - fn, _ = get_function!(mod, "jl_typeof", FT) - call!(B, FT, fn, [arg]) -end - -function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - world = enzyme_extract_world(fn) - - sizeT = convert(LLVMType, Csize_t) - psizeT = LLVM.PointerType(sizeT) - - primalvaltys = LLVM.Value[unsafe_to_llvm(B, Core.Typeof(func))] - for a in args - push!(primalvaltys, emit_jltypeof!(B, a)) - end - - meth = only(methods(func)) - tag = emit_apply_type!(B, Tuple, primalvaltys) - - # TT = meth.sig - # while TT isa UnionAll - # TT = TT.body - # end - # parms = TT.parameters - # - # tosv = primalvaltys - # if length(parms) > 0 && typeof(parms[end]) == Core.TypeofVararg - # tosv = LLVM.Value[tosv[1:length(parms)-1]..., emit_apply_type!(B, Tuple, tosv[length(parms):end])] - # end - # sv = emit_svec!(B, tosv[2:end]) - # - - meth = unsafe_to_llvm(B, meth) - - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - worlds, FT = get_function!( - mod, - "jl_gf_invoke_lookup_worlds", - LLVM.FunctionType(T_prjlvalue, 
[T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT]), - ) - EB = LLVM.IRBuilder() - position!(EB, first(LLVM.instructions(LLVM.entry(fn)))) - minworld = alloca!(EB, sizeT) - maxworld = alloca!(EB, sizeT) - store!(B, LLVM.ConstantInt(sizeT, 0), minworld) - store!(B, LLVM.ConstantInt(sizeT, -1), maxworld) - methodmatch = call!( - B, - FT, - worlds, - LLVM.Value[ - tag, - unsafe_to_llvm(B, nothing), - LLVM.ConstantInt(sizeT, world), - minworld, - maxworld, - ], - ) - # emit_jl!(B, methodmatch) - # emit_jl!(B, emit_jltypeof!(B, methodmatch)) - offset = 1 - AT = LLVM.ArrayType(T_prjlvalue, offset + 1) - methodmatch = addrspacecast!(B, methodmatch, LLVM.PointerType(T_jlvalue, Derived)) - methodmatch = bitcast!(B, methodmatch, LLVM.PointerType(AT, Derived)) - gep = LLVM.inbounds_gep!( - B, - AT, - methodmatch, - LLVM.Value[LLVM.ConstantInt(0), LLVM.ConstantInt(offset)], - ) - sv = LLVM.load!(B, T_prjlvalue, gep) - - fn, FT = get_function!( - mod, - "jl_specializations_get_linfo", - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, T_prjlvalue]), - ) - - mi = call!(B, FT, fn, [meth, tag, sv]) - - return mi -end - -function emit_writebarrier!(B, T) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - func, FT = declare_writebarrier!(mod) - return call!(B, FT, func, T) -end - - -function get_array_struct() - @static if VERSION < v"1.11-" - # JL_EXTENSION typedef struct { - # JL_DATA_TYPE - # void *data; - # #ifdef STORE_ARRAY_LEN (just true new newer versions) - # size_t length; - # #endif - # jl_array_flags_t flags; - # uint16_t elsize; // element size including alignment (dim 1 memory stride) - # uint32_t offset; // for 1-d only. does not need to get big. 
- # size_t nrows; - # union { - # // 1d - # size_t maxsize; - # // Nd - # size_t ncols; - # }; - # // other dim sizes go here for ndims > 2 - # - # // followed by alignment padding and inline data, or owner pointer - # } jl_array_t; - - i8 = LLVM.IntType(8) - ptrty = LLVM.PointerType(i8, 13) - sizeT = LLVM.IntType(8 * sizeof(Csize_t)) - arrayFlags = LLVM.IntType(16) - elsz = LLVM.IntType(16) - off = LLVM.IntType(32) - nrows = LLVM.IntType(8 * sizeof(Csize_t)) - - return LLVM.StructType([ptrty, sizeT, arrayFlags, elsz, off, nrows]; packed = true) - else - # JL_EXTENSION typedef struct { - # JL_DATA_TYPE - # size_t length; - # void *ptr; - # // followed by padding and inline data, or owner pointer - # #ifdef _P64 - # // union { - # // jl_value_t *owner; - # // T inl[]; - # // }; - # #else - # // - # // jl_value_t *owner; - # // size_t padding[1]; - # // T inl[]; - # #endif - # } jl_genericmemory_t; - # - # JL_EXTENSION typedef struct { - # JL_DATA_TYPE - # void *ptr_or_offset; - # jl_genericmemory_t *mem; - # } jl_genericmemoryref_t; - # - # JL_EXTENSION typedef struct { - # JL_DATA_TYPE - # jl_genericmemoryref_t ref; - # size_t dimsize[]; // length for 1-D, otherwise length is mem->length - # } jl_array_t; - i8 = LLVM.IntType(8) - ptrty = LLVM.PointerType(i8, 10) - sizeT = LLVM.IntType(8 * sizeof(Csize_t)) - return LLVM.StructType([ptrty, sizeT]; packed = true) - end -end - -function get_array_data(B, array) - i8 = LLVM.IntType(8) - ptrty = LLVM.PointerType(i8, 13) - array = LLVM.pointercast!( - B, - array, - LLVM.PointerType(ptrty, LLVM.addrspace(LLVM.value_type(array))), - ) - return LLVM.load!(B, ptrty, array) -end - -function get_array_elsz(B, array) - ST = get_array_struct() - elsz = LLVM.IntType(16) - array = LLVM.pointercast!( - B, - array, - LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), - ) - v = inbounds_gep!( - B, - ST, - array, - LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(3))], - ) - return LLVM.load!(B, elsz, v) -end - 
-function get_array_len(B, array) - if isa(array, LLVM.CallInst) - fn = LLVM.called_operand(array) - nm = "" - if isa(fn, LLVM.Function) - nm = LLVM.name(fn) - end - - for (fname, num) in ( - ("jl_alloc_array_1d", 1), - ("ijl_alloc_array_1d", 1), - ("jl_alloc_array_2d", 2), - ("jl_alloc_array_2d", 2), - ("jl_alloc_array_2d", 3), - ("jl_alloc_array_2d", 3), - ) - if nm == fname - res = operands(array)[2] - for i = 2:num - res = mul!(B, res, operands(array)[1+i]) - end - return res - end - end - end - ST = get_array_struct() - array = LLVM.pointercast!( - B, - array, - LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), - ) - v = inbounds_gep!( - B, - ST, - array, - LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))], - ) - sizeT = LLVM.IntType(8 * sizeof(Csize_t)) - return LLVM.load!(B, sizeT, v) -end - -function get_array_nrows(B, array) - ST = get_array_struct() - array = LLVM.pointercast!( - B, - array, - LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), - ) - v = inbounds_gep!( - B, - ST, - array, - LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))], - ) - nrows = LLVM.IntType(8 * sizeof(Csize_t)) - return LLVM.load!(B, nrows, v) -end - # Force sret struct Return2 ret1::Any @@ -1491,129 +1132,8 @@ struct Tape{TapeTy,ShadowTy,ResT} shadow_return::ShadowTy end -function emit_gc_preserve_begin(B::LLVM.IRBuilder, args = LLVM.Value[]) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - func, FT = get_function!( - mod, - "llvm.julia.gc_preserve_begin", - LLVM.FunctionType(LLVM.TokenType(), vararg = true), - ) - - token = call!(B, FT, func, args) - return token -end - -function emit_gc_preserve_end(B::LLVM.IRBuilder, token) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - func, FT = get_function!( - mod, - "llvm.julia.gc_preserve_end", - LLVM.FunctionType(LLVM.VoidType(), [LLVM.TokenType()]), - ) - - call!(B, FT, func, [token]) - return -end - -function 
allocate_sret!(B::LLVM.IRBuilder, N) - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - al = LLVM.alloca!(B, LLVM.ArrayType(T_prjlvalue, N)) - return al -end - -function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) - B = LLVM.IRBuilder() - position!(B, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) - allocate_sret!(B, N) -end - include("make_zero.jl") -function emit_error(B::LLVM.IRBuilder, orig, string, errty = EnzymeRuntimeException) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - - if !isa(string, LLVM.Value) - string = globalstring_ptr!(B, string, "enz_exception") - end - - ct = if occursin("ptx", LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) - - vt = LLVM.VoidType() - ptr = convert(LLVMType, Ptr{Cvoid}) - - exc, _ = - get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) - - string = ptrtoint!(B, string, ptr) - - call!(B, LLVM.function_type(exc), exc, [string]) - - framefn, ft = get_function!( - mod, - "gpu_report_exception_frame", - LLVM.FunctionType(vt, [LLVM.Int32Type(), ptr, ptr, LLVM.Int32Type()]), - ) - - if orig !== nothing - bt = GPUCompiler.backtrace(orig) - for (i, frame) in enumerate(bt) - idx = ConstantInt(parameters(ft)[1], i) - func = globalstring_ptr!(B, String(frame.func), "di_func") - func = ptrtoint!(B, func, ptr) - file = globalstring_ptr!(B, String(frame.file), "di_file") - file = ptrtoint!(B, file, ptr) - line = ConstantInt(parameters(ft)[4], frame.line) - call!(B, ft, framefn, [idx, func, file, line]) - end - end - - sigfn, sigft = get_function!( - mod, - "gpu_signal_exception", - LLVM.FunctionType(vt, LLVM.LLVMType[]), - ) - call!(B, sigft, sigfn) - trap_ft = LLVM.FunctionType(LLVM.VoidType()) - trap = if haskey(functions(mod), "llvm.trap") - functions(mod)["llvm.trap"] - else - LLVM.Function(mod, "llvm.trap", trap_ft) - end - call!(B, trap_ft, trap) - else - err = emit_allocobj!(B, errty) - err2 = 
bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) - store!(B, string, err2) - emit_jl_throw!( - B, - addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12)), - ) - end - - # 2. Call error function and insert unreachable - LLVM.API.LLVMAddCallSiteAttribute( - ct, - reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), - EnumAttribute("noreturn"), - ) - if EnzymeMutabilityException != errty - LLVM.API.LLVMAddCallSiteAttribute( - ct, - reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), - StringAttribute("enzyme_error"), - ) - end - return ct -end - function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world) funcspec = GPUCompiler.methodinstance(typeof(f), tt, world) nested_codegen!(mode, mod, funcspec, world) @@ -3817,6 +3337,8 @@ function annotate!(mod, mode) "ijl_box_float64", "ijl_box_int32", "ijl_box_int64", + "jl_alloc_genericmemory", + "ijl_alloc_genericmemory", "jl_alloc_array_1d", "jl_alloc_array_2d", "jl_alloc_array_3d", @@ -3825,6 +3347,10 @@ function annotate!(mod, mode) "ijl_alloc_array_3d", "jl_array_copy", "ijl_array_copy", + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice", + "jl_alloc_genericmemory", + "ijl_alloc_genericmemory", "jl_idtable_rehash", "ijl_idtable_rehash", "jl_f_tuple", @@ -3854,6 +3380,8 @@ function annotate!(mod, mode) boxfn in ( "jl_array_copy", "ijl_array_copy", + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice", "jl_idtable_rehash", "ijl_idtable_rehash", ) @@ -3876,6 +3404,8 @@ function annotate!(mod, mode) boxfn in ( "jl_array_copy", "ijl_array_copy", + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice", "jl_idtable_rehash", "ijl_idtable_rehash", ) @@ -3916,6 +3446,8 @@ function annotate!(mod, mode) boxfn in ( "jl_array_copy", "ijl_array_copy", + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice", "jl_idtable_rehash", "ijl_idtable_rehash", ) @@ -4260,6 
+3792,30 @@ function enzyme!( LLVM.API.LLVMValueRef, ) ), + "jl_genericmemory_copy_slice" => @cfunction( + inoutcopyslice_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "ijl_genericmemory_copy_slice" => @cfunction( + inoutcopyslice_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), "julia.pointer_from_objref" => @cfunction( inout_rule, UInt8, @@ -7479,7 +7035,11 @@ function GPUCompiler.codegen( nm == "ijl_alloc_array_2d" || nm == "jl_alloc_array_2d" || nm == "ijl_alloc_array_3d" || - nm == "jl_alloc_array_3d" + nm == "jl_alloc_array_3d" || + nm == "ijl_new_array" || + nm == "jl_new_array" || + nm == "jl_alloc_genericmemory" || + nm == "ijl_alloc_genericmemory" continue end if is_readonly(called) diff --git a/src/jlrt.jl b/src/jlrt.jl index 8338986e19..4f2ca71801 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -333,3 +333,667 @@ function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Valu return val end end + +function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + legal = true + found = [] + for arg in args + slegal, foundv = absint(arg) + if slegal + push!(found, foundv) + else + legal = false + break + end + end + + if legal + return unsafe_to_llvm(B, Ty{found...}) + end + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_pprjlvalue = LLVM.PointerType(T_prjlvalue) + T_int32 = LLVM.Int32Type() + + generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) + f_apply_type, _ = get_function!(mod, "jl_f_apply_type", generic_FT) + Ty = unsafe_to_llvm(B, Ty) + + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) + tag = call!( + B, + FT, + julia_call, + LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...], + ) + return tag +end + +function emit_tuple!(B, args)::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + legal = true + found = [] + for arg in args + slegal, foundv = absint(arg) + if slegal + push!(found, foundv) + else + legal = false + break + end + end + + if legal + return unsafe_to_llvm(B, (found...,)) + end + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_pprjlvalue = LLVM.PointerType(T_prjlvalue) + T_int32 = LLVM.Int32Type() + + generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) + f_apply_type, _ = get_function!(mod, "jl_f_tuple", generic_FT) + + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) + tag = call!( + B, + FT, + julia_call, + LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...], + ) + return tag +end + +function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + legal, val, byref = abs_typeof(arg) + if legal + return unsafe_to_llvm(B, val) + end + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) + fn, _ = get_function!(mod, "jl_typeof", FT) + call!(B, FT, fn, [arg]) +end + +function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + world = enzyme_extract_world(fn) + + sizeT = convert(LLVMType, Csize_t) + psizeT = LLVM.PointerType(sizeT) + + primalvaltys = LLVM.Value[unsafe_to_llvm(B, Core.Typeof(func))] + for a in args + push!(primalvaltys, emit_jltypeof!(B, a)) + end + + meth = only(methods(func)) + tag = emit_apply_type!(B, Tuple, primalvaltys) + + # TT = meth.sig + # while TT isa UnionAll + # TT = TT.body + # end + # parms = TT.parameters + # + # tosv = primalvaltys + # if length(parms) > 0 && typeof(parms[end]) == Core.TypeofVararg + # tosv = LLVM.Value[tosv[1:length(parms)-1]..., emit_apply_type!(B, Tuple, tosv[length(parms):end])] + # end + # sv = emit_svec!(B, tosv[2:end]) + # + + meth = unsafe_to_llvm(B, meth) + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + worlds, FT = get_function!( + mod, + "jl_gf_invoke_lookup_worlds", + LLVM.FunctionType(T_prjlvalue, 
[T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT]), + ) + EB = LLVM.IRBuilder() + position!(EB, first(LLVM.instructions(LLVM.entry(fn)))) + minworld = alloca!(EB, sizeT) + maxworld = alloca!(EB, sizeT) + store!(B, LLVM.ConstantInt(sizeT, 0), minworld) + store!(B, LLVM.ConstantInt(sizeT, -1), maxworld) + methodmatch = call!( + B, + FT, + worlds, + LLVM.Value[ + tag, + unsafe_to_llvm(B, nothing), + LLVM.ConstantInt(sizeT, world), + minworld, + maxworld, + ], + ) + # emit_jl!(B, methodmatch) + # emit_jl!(B, emit_jltypeof!(B, methodmatch)) + offset = 1 + AT = LLVM.ArrayType(T_prjlvalue, offset + 1) + methodmatch = addrspacecast!(B, methodmatch, LLVM.PointerType(T_jlvalue, Derived)) + methodmatch = bitcast!(B, methodmatch, LLVM.PointerType(AT, Derived)) + gep = LLVM.inbounds_gep!( + B, + AT, + methodmatch, + LLVM.Value[LLVM.ConstantInt(0), LLVM.ConstantInt(offset)], + ) + sv = LLVM.load!(B, T_prjlvalue, gep) + + fn, FT = get_function!( + mod, + "jl_specializations_get_linfo", + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, T_prjlvalue]), + ) + + mi = call!(B, FT, fn, [meth, tag, sv]) + + return mi +end + +function emit_writebarrier!(B, T) + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + func, FT = declare_writebarrier!(mod) + return call!(B, FT, func, T) +end + + +function get_array_struct() + @static if VERSION < v"1.11-" + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *data; + # #ifdef STORE_ARRAY_LEN (just true new newer versions) + # size_t length; + # #endif + # jl_array_flags_t flags; + # uint16_t elsize; // element size including alignment (dim 1 memory stride) + # uint32_t offset; // for 1-d only. does not need to get big. 
+ # size_t nrows; + # union { + # // 1d + # size_t maxsize; + # // Nd + # size_t ncols; + # }; + # // other dim sizes go here for ndims > 2 + # + # // followed by alignment padding and inline data, or owner pointer + # } jl_array_t; + + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 13) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + arrayFlags = LLVM.IntType(16) + elsz = LLVM.IntType(16) + off = LLVM.IntType(32) + nrows = LLVM.IntType(8 * sizeof(Csize_t)) + + return LLVM.StructType([ptrty, sizeT, arrayFlags, elsz, off, nrows]; packed = true) + else + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # size_t length; + # void *ptr; + # // followed by padding and inline data, or owner pointer + # #ifdef _P64 + # // union { + # // jl_value_t *owner; + # // T inl[]; + # // }; + # #else + # // + # // jl_value_t *owner; + # // size_t padding[1]; + # // T inl[]; + # #endif + # } jl_genericmemory_t; + # + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *ptr_or_offset; + # jl_genericmemory_t *mem; + # } jl_genericmemoryref_t; + # + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # jl_genericmemoryref_t ref; + # size_t dimsize[]; // length for 1-D, otherwise length is mem->length + # } jl_array_t; + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 10) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + return LLVM.StructType([ptrty, sizeT]; packed = true) + end +end + +function get_memory_struct() + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # size_t length; + # void *ptr; + # // followed by padding and inline data, or owner pointer + # #ifdef _P64 + # // union { + # // jl_value_t *owner; + # // T inl[]; + # // }; + # #else + # // + # // jl_value_t *owner; + # // size_t padding[1]; + # // T inl[]; + # #endif + # } jl_genericmemory_t; + + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + + return LLVM.StructType([sizeT, ptrty]; packed = true) +end + +function get_memory_data(B, array) + mty = 
get_memory_struct() + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(mty, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + mty, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))], + ) + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8) + return LLVM.load!(B, ptrty, v) +end + +function get_layout_struct() + # typedef struct { + # uint32_t size; + # uint32_t nfields; + # uint32_t npointers; // number of pointers embedded inside + # int32_t first_ptr; // index of the first pointer (or -1) + # uint16_t alignment; // strictest alignment over all fields + # struct { // combine these fields into a struct so that we can take addressof them + # uint16_t haspadding : 1; // has internal undefined bytes + # uint16_t fielddesc_type : 2; // 0 -> 8, 1 -> 16, 2 -> 32, 3 -> foreign type + # // metadata bit only for GenericMemory eltype layout + # uint16_t arrayelem_isboxed : 1; + # uint16_t arrayelem_isunion : 1; + # // If set, this type's egality can be determined entirely by comparing + # // the non-padding bits of this datatype. 
+ # uint16_t isbitsegal : 1; + # uint16_t padding : 10; + # } flags; + # // union { + # // jl_fielddesc8_t field8[nfields]; + # // jl_fielddesc16_t field16[nfields]; + # // jl_fielddesc32_t field32[nfields]; + # // }; + # // union { // offsets relative to data start in words + # // uint8_t ptr8[npointers]; + # // uint16_t ptr16[npointers]; + # // uint32_t ptr32[npointers]; + # // }; + # } jl_datatype_layout_t; + i32 = LLVM.IntType(32) + i16 = LLVM.IntType(16) + return LLVM.StructType([i32, i32, i32, i32, i16, i16]; packed = true) +end + +function get_datatype_struct() + # typedef struct _jl_datatype_t { + # JL_DATA_TYPE + # jl_typename_t *name; + # struct _jl_datatype_t *super; + # jl_svec_t *parameters; + # jl_svec_t *types; + # jl_value_t *instance; // for singletons + # const jl_datatype_layout_t *layout; + # // memoized properties (set on construction) + # uint32_t hash; + # uint16_t hasfreetypevars:1; // majority part of isconcrete computation + # uint16_t isconcretetype:1; // whether this type can have instances + # uint16_t isdispatchtuple:1; // aka isleaftupletype + # uint16_t isbitstype:1; // relevant query for C-api and type-parameters + # uint16_t zeroinit:1; // if one or more fields requires zero-initialization + # uint16_t has_concrete_subtype:1; // If clear, no value will have this datatype + # uint16_t maybe_subtype_of_cache:1; // Computational bit for has_concrete_supertype. See description in jltypes.c. 
+ # uint16_t isprimitivetype:1; // whether this is declared with 'primitive type' keyword (sized, no fields, and immutable) + # uint16_t ismutationfree:1; // whether any mutable memory is reachable through this type (in the type or via fields) + # uint16_t isidentityfree:1; // whether this type or any object reachable through its fields has non-content-based identity + # uint16_t smalltag:6; // whether this type has a small-tag optimization + # } jl_datatype_t; + jlvaluet = LLVM.PointerType(LLVM.StructType(LLVMType[]), 10) + i32 = LLVM.IntType(32) + i16 = LLVM.IntType(16) + return LLVM.StructType([jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, i32, i16]; packed = true) +end + +function get_array_data(B, array) + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 13) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ptrty, LLVM.addrspace(LLVM.value_type(array))), + ) + return LLVM.load!(B, ptrty, array) +end + +function get_array_elsz(B, array) + ST = get_array_struct() + elsz = LLVM.IntType(16) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(3))], + ) + return LLVM.load!(B, elsz, v) +end + +function emit_layout_of_type!(B, ty) + legal, JTy = absint(ty) + ls = get_layout_struct() + lptr = LLVM.PointerType(ls, 10) + if legal + return LLVM.const_inttoptr(LLVM.ConstantInt(Base.reinterpret(UInt, JTy.layout)), lptr) + end + @assert !isa(ty, LLVM.ConstantExpr) + @assert !isa(ty, LLVM.Constant) + dt = get_datatype_struct() + lty = bitcast!(B, ty, LLVM.PointerType(dt, addrspace(value_type(ty)))) + layoutp = inbounds_gep!(B, dt, ty, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))], + ) + layout = load!(B, lptr, layoutp) + return layout +end + +function emit_memorytype_elsz!(B, ty) + legal, JTy = absint(ty) + if legal + res = unsafe_load(reinterpret(Ptr{UInt32}, 
JTy.layout)) + return LLVM.ConstantInt(res) + end + ty = emit_layout_of_type!(B, ty) + @assert !isa(ty, LLVM.ConstantExpr) + @assert !isa(ty, LLVM.Constant) + i32 = LLVM.IntType(32) + lty = bitcast!(B, ty, LLVM.PointerType(i32, addrspace(value_type(ty)))) + return load!(B, i32, lty) +end + +function get_memory_elsz(B, array) + ty = emit_jltypeof!(B, array) + return emit_memorytype_elsz!(B, ty) +end + +function get_array_len(B, array) + if isa(array, LLVM.CallInst) + fn = LLVM.called_operand(array) + nm = "" + if isa(fn, LLVM.Function) + nm = LLVM.name(fn) + end + + for (fname, num) in ( + ("jl_alloc_array_1d", 1), + ("ijl_alloc_array_1d", 1), + ("jl_alloc_array_2d", 2), + ("jl_alloc_array_2d", 2), + ("jl_alloc_array_2d", 3), + ("jl_alloc_array_2d", 3), + ) + if nm == fname + res = operands(array)[2] + for i = 2:num + res = mul!(B, res, operands(array)[1+i]) + end + return res + end + end + end + ST = get_array_struct() + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))], + ) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + return LLVM.load!(B, sizeT, v) +end + +function get_memory_len(B, array) + if isa(array, LLVM.CallInst) + fn = LLVM.called_operand(array) + nm = "" + if isa(fn, LLVM.Function) + nm = LLVM.name(fn) + end + + for (fname, num) in ( + ("jl_alloc_genericmemory", 1), + ("ijl_alloc_genericmemory", 1), + ) + if nm == fname + res = operands(array)[2] + for i = 2:num + res = mul!(B, res, operands(array)[1+i]) + end + return res + end + end + end + ST = get_memory_struct() + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(0))], + ) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + return LLVM.load!(B, sizeT, v) +end + +function 
get_array_nrows(B, array) + ST = get_array_struct() + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))], + ) + nrows = LLVM.IntType(8 * sizeof(Csize_t)) + return LLVM.load!(B, nrows, v) +end + +function emit_gc_preserve_begin(B::LLVM.IRBuilder, args = LLVM.Value[]) + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + func, FT = get_function!( + mod, + "llvm.julia.gc_preserve_begin", + LLVM.FunctionType(LLVM.TokenType(), vararg = true), + ) + + token = call!(B, FT, func, args) + return token +end + +function emit_gc_preserve_end(B::LLVM.IRBuilder, token) + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + func, FT = get_function!( + mod, + "llvm.julia.gc_preserve_end", + LLVM.FunctionType(LLVM.VoidType(), [LLVM.TokenType()]), + ) + + call!(B, FT, func, [token]) + return +end + +function allocate_sret!(B::LLVM.IRBuilder, N) + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + al = LLVM.alloca!(B, LLVM.ArrayType(T_prjlvalue, N)) + return al +end + +function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) + B = LLVM.IRBuilder() + position!(B, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + allocate_sret!(B, N) +end + +function emit_error(B::LLVM.IRBuilder, orig, string, errty = EnzymeRuntimeException) + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + + if !isa(string, LLVM.Value) + string = globalstring_ptr!(B, string, "enz_exception") + end + + ct = if occursin("ptx", LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) + + vt = LLVM.VoidType() + ptr = convert(LLVMType, Ptr{Cvoid}) + + exc, _ = + get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) + + string = ptrtoint!(B, string, ptr) + + call!(B, 
LLVM.function_type(exc), exc, [string]) + + framefn, ft = get_function!( + mod, + "gpu_report_exception_frame", + LLVM.FunctionType(vt, [LLVM.Int32Type(), ptr, ptr, LLVM.Int32Type()]), + ) + + if orig !== nothing + bt = GPUCompiler.backtrace(orig) + for (i, frame) in enumerate(bt) + idx = ConstantInt(parameters(ft)[1], i) + func = globalstring_ptr!(B, String(frame.func), "di_func") + func = ptrtoint!(B, func, ptr) + file = globalstring_ptr!(B, String(frame.file), "di_file") + file = ptrtoint!(B, file, ptr) + line = ConstantInt(parameters(ft)[4], frame.line) + call!(B, ft, framefn, [idx, func, file, line]) + end + end + + sigfn, sigft = get_function!( + mod, + "gpu_signal_exception", + LLVM.FunctionType(vt, LLVM.LLVMType[]), + ) + call!(B, sigft, sigfn) + trap_ft = LLVM.FunctionType(LLVM.VoidType()) + trap = if haskey(functions(mod), "llvm.trap") + functions(mod)["llvm.trap"] + else + LLVM.Function(mod, "llvm.trap", trap_ft) + end + call!(B, trap_ft, trap) + else + err = emit_allocobj!(B, errty) + err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) + store!(B, string, err2) + emit_jl_throw!( + B, + addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12)), + ) + end + + # 2. 
Call error function and insert unreachable + LLVM.API.LLVMAddCallSiteAttribute( + ct, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + EnumAttribute("noreturn"), + ) + if EnzymeMutabilityException != errty + LLVM.API.LLVMAddCallSiteAttribute( + ct, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + StringAttribute("enzyme_error"), + ) + end + return ct +end + diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 1c1447dd65..83a9a22cd4 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -27,7 +27,11 @@ function array_shadow_handler( typ = eltype(typ) b = LLVM.IRBuilder(B) - orig = LLVM.Value(OrigCI) + orig = LLVM.Value(OrigCI)::LLVM.CallInst + + nm = LLVM.name(LLVM.called_operand(orig)::LLVM.Function) + + memory = nm == "jl_alloc_genericmemory" || nm == "ijl_alloc_genericmemory" vals = LLVM.Value[] valTys = API.CValueType[] @@ -38,7 +42,11 @@ function array_shadow_handler( anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, false) #=lookup=# - prod = get_array_len(b, anti) + prod = if memory + get_memory_len(b, anti) + else + get_array_len(b, anti) + end isunboxed, elsz, al = Base.uniontype_layout(typ) @@ -66,7 +74,11 @@ function array_shadow_handler( end i8 = LLVM.IntType(8) - toset = get_array_data(b, anti) + toset = if memory + get_memory_data(b, anti) + else + get_array_data(b, anti) + end mcall = LLVM.memset!(b, toset, LLVM.ConstantInt(i8, 0, false), tot, al) @@ -90,64 +102,13 @@ end @inline function register_alloc_rules() register_alloc_handler!( - ("jl_alloc_array_1d", "ijl_alloc_array_1d"), - @cfunction( - array_shadow_handler, - LLVM.API.LLVMValueRef, - ( - LLVM.API.LLVMBuilderRef, - LLVM.API.LLVMValueRef, - Csize_t, - Ptr{LLVM.API.LLVMValueRef}, - API.EnzymeGradientUtilsRef, - ) + ( + "jl_alloc_array_1d", "ijl_alloc_array_1d", + "jl_alloc_array_2d", "ijl_alloc_array_2d", + "jl_alloc_array_3d", "ijl_alloc_array_3d", + "jl_new_array", 
"ijl_new_array", + "jl_alloc_genericmemory", "ijl_alloc_genericmemory", ), - @cfunction( - null_free_handler, - LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) - ) - ) - register_alloc_handler!( - ("jl_alloc_array_2d", "ijl_alloc_array_2d"), - @cfunction( - array_shadow_handler, - LLVM.API.LLVMValueRef, - ( - LLVM.API.LLVMBuilderRef, - LLVM.API.LLVMValueRef, - Csize_t, - Ptr{LLVM.API.LLVMValueRef}, - API.EnzymeGradientUtilsRef, - ) - ), - @cfunction( - null_free_handler, - LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) - ) - ) - register_alloc_handler!( - ("jl_alloc_array_3d", "ijl_alloc_array_3d"), - @cfunction( - array_shadow_handler, - LLVM.API.LLVMValueRef, - ( - LLVM.API.LLVMBuilderRef, - LLVM.API.LLVMValueRef, - Csize_t, - Ptr{LLVM.API.LLVMValueRef}, - API.EnzymeGradientUtilsRef, - ) - ), - @cfunction( - null_free_handler, - LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) - ) - ) - register_alloc_handler!( - ("jl_new_array", "ijl_new_array"), @cfunction( array_shadow_handler, LLVM.API.LLVMValueRef, diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 3c4b95d8ee..899c9f6d43 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -549,7 +549,10 @@ end return false end -function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) +# Optionally takes a length if requested +# If this is a memory, pass memoryptr= +function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothing, memoryptr=nothing) + memory = memoryptr != nothing needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) activep = API.EnzymeGradientUtilsGetReturnDiffeType( @@ -569,22 +572,16 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) shadowdst = invert_pointer(gutils, orig, B) end - # size_t len = jl_array_len(ary); - # size_t elsz = ary->elsize; - # memcpy(new_ary->data, 
ary->data, len * elsz); - # JL_EXTENSION typedef struct { - # JL_DATA_TYPE - # void *data; - # #ifdef STORE_ARRAY_LEN - # size_t length; - # #endif - # jl_array_flags_t flags; - # uint16_t elsize; // element size including alignment (dim 1 memory stride) - tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, orig)) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) dl = string(LLVM.datalayout(mod)) - API.EnzymeTypeTreeLookupEq(tt, 1, dl) + # memory stores the data pointer after a length + if memory + API.EnzymeTypeTreeLookupEq(tt, 2*sizeof(Int), dl) + API.EnzymeTypeTreeShiftIndiciesEq(tt, dl, sizeof(Int), sizeof(Int), 0) + else + API.EnzymeTypeTreeLookupEq(tt, sizeof(Int), dl) + end data0!(tt) ct = API.EnzymeTypeTreeInner0(tt) @@ -596,7 +593,7 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) emit_error( B, orig, - "Enzyme: Unknown concrete type in arraycopy_common. tt: " * string(tt), + "Enzyme: Unknown concrete type in arraycopy_common. tt: " * string(tt)* " " * string(orig) * " " * string(abs_typeof(orig)), ) return nothing end @@ -605,14 +602,7 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) ctx = LLVM.context(orig) secretty = API.EnzymeConcreteTypeIsFloat(ct) - off = sizeof(Cstring) - if true # STORE_ARRAY_LEN - off += sizeof(Csize_t) - end - #jl_array_flags_t - off += 2 - - actualOp = new_from_original(gutils, origArg) + actualOp = new_from_original(gutils, shadowsrc) if fwd B0 = B elseif typeof(actualOp) <: LLVM.Argument @@ -631,15 +621,36 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) while isa(nextInst, LLVM.PHIInst) nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(nextInst)) end + if len != nothing + nextInst = new_from_original(gutils, orig) + end position!(B0, nextInst) end - elSize = get_array_elsz(B0, actualOp) + elSize = if memory + get_memory_elsz(B0, actualOp) + else + get_array_elsz(B0, actualOp) + end + elSize = LLVM.zext!(B0, elSize, LLVM.IntType(8 * 
sizeof(Csize_t))) - len = get_array_len(B0, actualOp) + if len == nothing + if memory + len = get_memory_len(B0, actualOp) + else + len = get_array_len(B0, actualOp) + end + elseif !fwd + # len = lookup_value(gutils, len, B) + end + + if memory + length = LLVM.mul!(B0, len, elSize) + else + length = LLVM.mul!(B0, len, elSize) + end - length = LLVM.mul!(B0, len, elSize) isVolatile = LLVM.ConstantInt(LLVM.IntType(1), 0) # forward pass copy already done by underlying call @@ -649,10 +660,26 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) if !fwd shadowdst = lookup_value(gutils, shadowdst, B) end - shadowsrc = invert_pointer(gutils, origArg, B) - if !fwd - shadowsrc = lookup_value(gutils, shadowsrc, B) - end + + + lookup_src = true + + if memory + if fwd + shadowsrc = memoryptr + lookup_src = false + else + shadowsrc = invert_pointer(gutils, shadowsrc, B) + if !fwd + shadowsrc = lookup_value(gutils, shadowsrc, B) + end + end + else + shadowsrc = invert_pointer(gutils, shadowsrc, B) + if !fwd + shadowsrc = lookup_value(gutils, shadowsrc, B) + end + end width = get_width(gutils) @@ -669,93 +696,233 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) algn = 0 i8 = LLVM.IntType(8) - if width == 1 + for i = 1:width - shadowsrc = get_array_data(B, shadowsrc) - shadowdst = get_array_data(B, shadowdst) + evsrc = if width == 1 + shadowsrc + else + extract_value!(B, shadowsrc, i - 1) + end + evdst = if width == 1 + shadowdst + else + extract_value!(B, shadowdst, i - 1) + end - if fwd && secretty != nothing - LLVM.memset!(B, shadowdst, LLVM.ConstantInt(i8, 0, false), length, algn) - end + # src already has done the lookup from the argument + shadowsrc0 = if lookup_src + if memory + get_memory_data(B, evsrc) + else + get_array_data(B, evsrc) + end + else + evsrc + end - API.sub_transfer( - gutils, - fwd ? 
API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, - secretty, - intrinsic, - 1, - 1, - 0, - false, - shadowdst, - false, - shadowsrc, - length, - isVolatile, - orig, - allowForward, - !fwd, - ) #=shadowsLookedUp=# + shadowdst0 = if memory + get_memory_data(B, evdst) + else + get_array_data(B, evdst) + end - else - for i = 1:width + if fwd && secretty != nothing + LLVM.memset!(B, shadowdst0, LLVM.ConstantInt(i8, 0, false), length, algn) + end - evsrc = extract_value!(B, shadowsrc, i - 1) - evdst = extract_value!(B, shadowdst, i - 1) + API.sub_transfer( + gutils, + fwd ? API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, + secretty, + intrinsic, + 1, + 1, + 0, + false, + shadowdst0, + false, + shadowsrc0, + length, + isVolatile, + orig, + allowForward, + !fwd, + ) #=shadowsLookedUp=# + end - shadowsrc0 = get_array_data(B, evsrc) - shadowdst0 = get_array_data(B, evdst) + return nothing +end - if fwd && secretty != nothing - LLVM.memset!(B, shadowdst0, LLVM.ConstantInt(i8, 0, false), length, algn) - end +@register_aug function arraycopy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL + return true + end + arraycopy_fwd(B, orig, gutils, normalR, shadowR) - API.sub_transfer( - gutils, - fwd ? 
API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, - secretty, - intrinsic, - 1, - 1, - 0, - false, - shadowdst0, - false, - shadowsrc0, + origops = LLVM.operands(orig) + + if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) + shadowres = LLVM.Value(unsafe_load(shadowR)) + + arraycopy_common(true, B, orig, origops[1], gutils, shadowres) + end + + return false +end + +@register_rev function arraycopy_rev(B, orig, gutils, tape) + origops = LLVM.operands(orig) + if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) + arraycopy_common(false, B, orig, origops[1], gutils, nothing) + end + + return nothing +end + +@register_fwd function genericmemory_copy_slice_fwd(B, orig, gutils, normalR, shadowR) + ctx = LLVM.context(orig) + + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL + return true + end + + origops = LLVM.operands(orig) + + width = get_width(gutils) + + shadowin = invert_pointer(gutils, origops[1], B) + shadowdata = invert_pointer(gutils, origops[2], B) + len = new_from_original(gutils, origops[3]) + + i8 = LLVM.IntType(8) + algn = 0 + + if width == 1 + shadowres = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + [shadowin, shadowdata, len], + [API.VT_Shadow, API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# + + # TODO zero based off runtime types, rather than presume floatlike? 
+ if is_constant_value(gutils, origops[1]) + elSize = get_memory_elsz(B, shadowin) + elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) + length = LLVM.mul!(B, len, elSize) + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr" + LLVM.memset!( + B, + shadowdata, + LLVM.ConstantInt(i8, 0, false), length, - isVolatile, + algn, + ) + end + if get_runtime_activity(gutils) + prev = new_from_original(gutils, orig) + shadowres = LLVM.select!( + B, + LLVM.icmp!( + B, + LLVM.API.LLVMIntNE, + shadowin, + new_from_original(gutils, origops[1]), + ), + shadowres, + prev, + ) + API.moveBefore(prev, shadowres, B) + end + else + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + ev = extract_value!(B, shadowin, idx - 1) + ev2 = extract_value!(B, shadowdata, idx - 1) + callv = call_samefunc_with_inverted_bundles!( + B, + gutils, orig, - allowForward, - !fwd, - ) #=shadowsLookedUp=# + [ev, ev2, len], + [API.VT_Shadow, API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# + if is_constant_value(gutils, origops[1]) + elSize = get_array_elsz(B, ev) + elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) + length = LLVM.mul!(B, len, elSize) + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr" + LLVM.memset!( + B, + ev2, + LLVM.ConstantInt(i8, 0, false), + length, + algn, + ) + end + if get_runtime_activity(gutils) + prev = new_from_original(gutils, orig) + callv = LLVM.select!( + B, + LLVM.icmp!( + B, + LLVM.API.LLVMIntNE, + ev, + new_from_original(gutils, origops[1]), + ), + callv, + prev, + ) + if idx == 1 + API.moveBefore(prev, callv, B) + end + end 
+ shadowres = insert_value!(B, shadowres, callv, idx - 1) end - end - return nothing + unsafe_store!(shadowR, shadowres.ref) + return false end -@register_aug function arraycopy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function genericmemory_copy_slice_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL return true end - arraycopy_fwd(B, orig, gutils, normalR, shadowR) + genericmemory_copy_slice_fwd(B, orig, gutils, normalR, shadowR) origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) shadowres = LLVM.Value(unsafe_load(shadowR)) - arraycopy_common(true, B, orig, origops[1], gutils, shadowres) #=fwd=# + len = new_from_original(gutils, origops[3]) + memoryptr = new_from_original(gutils, origops[2]) + arraycopy_common(true, B, orig, origops[1], gutils, shadowres; len, memoryptr) end return false end -@register_rev function arraycopy_rev(B, orig, gutils, tape) +@register_rev function genericmemory_copy_slice_rev(B, orig, gutils, tape) origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) - arraycopy_common(false, B, orig, origops[1], gutils, nothing) #=fwd=# + len = new_from_original(gutils, origops[3]) + memoryptr = new_from_original(gutils, origops[2]) + arraycopy_common(false, B, orig, origops[1], gutils, nothing; len, memoryptr) end return nothing @@ -2010,6 +2177,12 @@ end @revfunc(arraycopy_rev), @fwdfunc(arraycopy_fwd), ) + register_handler!( + ("jl_genericmemory_copy_slice", "ijl_genericmemory_copy_slice"), + @augfunc(genericmemory_copy_slice_augfwd), + @revfunc(genericmemory_copy_slice_rev), + @fwdfunc(genericmemory_copy_slice_fwd), + ) register_handler!( ("jl_reshape_array", "ijl_reshape_array"), @augfunc(arrayreshape_augfwd), diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 2a2d1032c1..2cba33c14e 100644 --- a/src/rules/typerules.jl +++ 
b/src/rules/typerules.jl @@ -52,3 +52,43 @@ function inout_rule( end return UInt8(false) end + +function inoutcopyslice_rule( + direction::Cint, + ret::API.CTypeTreeRef, + args::Ptr{API.CTypeTreeRef}, + known_values::Ptr{API.IntList}, + numArgs::Csize_t, + val::LLVM.API.LLVMValueRef, +)::UInt8 + if numArgs != 1 + return UInt8(false) + end + inst = LLVM.Instruction(val) + + legal, typ = abs_typeof(inst) + + if legal + if (direction & API.DOWN) != 0 + ctx = LLVM.context(inst) + dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) + if GPUCompiler.deserves_retbox(typ) + typ = Ptr{typ} + end + rest = typetree(typ, ctx, dl) + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) + @assert legal + end + return UInt8(false) + end + + if (direction & API.UP) != 0 + changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args), ret) + @assert legal + end + if (direction & API.DOWN) != 0 + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args)) + @assert legal + end + return UInt8(false) +end From 660dc228b9238bb36e00cb5dbca0ea1d70933029 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 17 Oct 2024 00:35:25 -0500 Subject: [PATCH 364/495] Partial newstructt info (#1976) * Partial newstructt info * More absint * fixup load if mixed * fixup * fix * cleanup * cleanup --- src/absint.jl | 29 ++++++++++++ src/jlrt.jl | 58 ++++++++++++++++++++++-- src/rules/typeunstablerules.jl | 83 ++++++++++++++++++++++++---------- 3 files changed, 141 insertions(+), 29 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 5519b0b862..4ed9d6c91b 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -348,6 +348,35 @@ function abs_typeof( return (true, res, GPUCompiler.BITS_REF) end end + + if nm == "jl_f__apply_iterate" || nm == "ijl_f__apply_iterate" + index += 1 + found = [] + unionalls = [] + legal, iterfn = absint(operands(arg)[index]) + index += 1 + if legal && iterfn == Base.iterate + legal0, combfn = absint(operands(arg)[index]) + index 
+= 1 + if legal0 && combfn == Core.apply_type && partial + return (true, Type, GPUCompiler.BITS_REF) + end + resvals = [] + while index != length(operands(arg)) + legal, pval, _ = abs_typeof(operands(arg)[index], partial) + if !legal + break + end + push!(resvals, pval) + index+=1 + end + if legal0 && legal && combfn == Base.tuple && partial && length(resvals) == 1 + if resvals[1] <: Vector + return (true, Tuple{Vararg{eltype(resvals[1])}}, GPUCompiler.BITS_REF) + end + end + end + end end if nm == "julia.call" diff --git a/src/jlrt.jl b/src/jlrt.jl index 4f2ca71801..47c5ce3d65 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -122,6 +122,22 @@ function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value call!(B, FT, fn, [val]) end +function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::LLVM.Value))::LLVM.Value + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + ity = LLVM.IntType(8*sizeof(Int)) + FT = LLVM.FunctionType(ity, [T_prjlvalue, T_prjlvalue]) + fn, _ = get_function!(mod, "jl_isa", FT) + call!(B, FT, fn, [val, val]) +end + +function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::Type))::LLVM.Value + emit_jl_isa!(B, val, unsafe_to_llvm(B, ty)) +end + function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -283,12 +299,27 @@ function emit_svec!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value end -function val_from_byref_if_mixed(B::LLVM.IRBuilder, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) +function load_if_mixed(oval::OT, val::VT) where {OT, VT} + if !(oval isa Base.RefValue) && (val isa Base.RefValue) + return val[] + else + return val + end +end + +function val_from_byref_if_mixed(B::LLVM.IRBuilder, 
gutils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) + world = enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(oval) if !legal - throw(AssertionError("Could not determine type of value within jl_newstructt arg: $(string(oval))")) + legal, TT, _ = abs_typeof(oval, true) + if legal + act = active_reg_inner(TT, (), world) + if act == AnyState + return val + end + end + return emit_apply_generic!(B, [unsafe_to_llvm(B, load_if_mixed), new_from_original(gutils, oval), val]) end - world = enzyme_extract_world(LLVM.parent(position(B))) act = active_reg_inner(TT, (), world) if act == ActiveState || act == MixedState legal2, TT2, _ = abs_typeof(val) @@ -316,10 +347,27 @@ function val_from_byref_if_mixed(B::LLVM.IRBuilder, @nospecialize(oval::LLVM.Val end end +function ref_if_mixed(val::VT) where {VT} + if active_reg_inner(Core.Typeof(val), (), nothing, Val(true)) == ActiveState + return Ref(val) + else + return val + end +end + function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value)) - legal, TT, _ = abs_typeof(val) - @assert legal world = enzyme_extract_world(LLVM.parent(position(B))) + legal, TT, _ = abs_typeof(val) + if !legal + legal, TT, _ = abs_typeof(val, true) + act = active_reg_inner(TT, (), world) + if act == AnyState + return val + end + if !legal + return emit_apply_generic!(B, [unsafe_to_llvm(B, ref_if_mixed), val]) + end + end act = active_reg_inner(TT, (), world) if act == ActiveState || act == MixedState diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 6371e409b1..723ed23a31 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -856,7 +856,7 @@ end shadowsin = invert_pointer(gutils, origops[2], B) if width == 1 - vals = [new_from_original(gutils, origops[1]), val_from_byref_if_mixed(B, origops[2], shadowsin)] + vals = [new_from_original(gutils, origops[1]), val_from_byref_if_mixed(B, gutils, origops[2], shadowsin)] shadowres 
= LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) callconv!(shadowres, callconv(orig)) shadowres = byref_from_val_if_mixed(B, shadowres) @@ -866,7 +866,7 @@ end for idx = 1:width vals = [ new_from_original(gutils, origops[1]), - val_from_byref_if_mixed(B, origops[2], extract_value!(B, shadowsin, idx - 1)), + val_from_byref_if_mixed(B, gutils, origops[2], extract_value!(B, shadowsin, idx - 1)), ] tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) callconv!(tmp, callconv(orig)) @@ -877,11 +877,15 @@ end unsafe_store!(shadowR, shadowres.ref) legal, TT, _ = abs_typeof(orig) - @assert legal - world = enzyme_extract_world(LLVM.parent(position(B))) - act = active_reg_inner(TT, (), world) - if act == ActiveState || act == MixedState + if !legal unsafe_store!(tapeR, shadowres.ref) + else + @assert legal + world = enzyme_extract_world(LLVM.parent(position(B))) + act = active_reg_inner(TT, (), world) + if act == ActiveState || act == MixedState + unsafe_store!(tapeR, shadowres.ref) + end end return false @@ -900,7 +904,7 @@ end end end -@generated function runtime_newstructt_rev(::Val{Width}, revres0::RR0, revarg0::RA0, args::Vararg{Any, N}) where {Width, RR0, RA0, N} +@generated function runtime_newstructt_rev(::Val{Width}, origres::Type{OR}, revres0::RR0, revarg0::RA0, args::Vararg{Any, N}) where {Width, OR, RR0, RA0, N} exprs = Expr[] for i in 1:Width dres = if i == 1 @@ -913,19 +917,37 @@ end else :(args[$(2*(i-2)+1+1)]) end - push!(exprs, quote - @assert $dres isa Base.RefValue - if $darg isa Base.RefValue - tmparg = $darg[] - tmpres = $dres[] - $darg[] = recursive_tuple(Val(length(tmparg)), tmparg, tmpres) - else - error( - "Enzyme Mutability Error: Cannot accumulate in place to immutable value " * - string($darg), - ) - end - end) + if OR == Nothing + push!(exprs, quote + @assert $dres isa Base.RefValue + if $darg isa Base.RefValue + tmparg = $darg[] + tmpres = $dres[] + $darg[] = recursive_tuple(Val(length(tmparg)), tmparg, 
tmpres) + else + error( + "Enzyme Mutability Error: Cannot accumulate in place to immutable value " * + string($darg), + ) + end + end) + elseif OR <: Base.RefValue + else + push!(exprs, quote + if $dres isa Base.RefValue + if $darg isa Base.RefValue + tmparg = $darg[] + tmpres = $dres[] + $darg[] = recursive_tuple(Val(length(tmparg)), tmparg, tmpres) + else + error( + "Enzyme Mutability Error: Cannot accumulate in place to immutable value " * + string($darg), + ) + end + end + end) + end end expr = quote Base.@_inline_meta @@ -959,15 +981,28 @@ end width = get_width(gutils) legal, TT, _ = abs_typeof(orig) - @assert legal - world = enzyme_extract_world(LLVM.parent(position(B))) - act = active_reg_inner(TT, (), world) - if act == ActiveState || act == MixedState + torun = false + if legal + @assert legal + world = enzyme_extract_world(LLVM.parent(position(B))) + act = active_reg_inner(TT, (), world) + torun = act == ActiveState || act == MixedState + else + torun = true + end + + if torun vals = LLVM.Value[ unsafe_to_llvm(B, runtime_newstructt_rev), unsafe_to_llvm(B, Val(Int(width))), ] + if legal + push!(vals, unsafe_to_llvm(B, Nothing)) + else + push!(vals, lookup_value(gutils, new_from_original(gutils, origops[1]), B)) + end + shadowsin = lookup_value(gutils, invert_pointer(gutils, origops[2], B), B) if width == 1 push!(vals, tape) From d862d54e328cd5c9763fb9b60a2204332fa401d0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 17 Oct 2024 00:36:50 -0500 Subject: [PATCH 365/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ad6ee87529..d46d9dd9df 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.10" +version = "0.13.11" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 1032b7121ead26cd9e7cc7b7c0548cb1a4696ce6 Mon Sep 17 
00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 17 Oct 2024 07:56:23 +0200 Subject: [PATCH 366/495] Revert "Revert "Fix code coverage & update action versions"" (#1977) * Revert "Revert "Fix code coverage & update action versions (#1954)" (#1974)" This reverts commit e2b0e41ea770a7a0a7e9a8566b975e48d340f45e. * Cache action versions * More cache v2 --- .github/workflows/CI.yml | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 60d713c529..c541e97a19 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -54,13 +54,13 @@ jobs: version: '1.11' assertions: true steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 if: ${{ ! matrix.assertions }} with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 if: ${{ matrix.assertions }} with: repository: 'JuliaLang/julia' @@ -72,7 +72,7 @@ jobs: sed -i.bak 's/exit 2/exit 0/g' julia/deps/tools/jlchecksum make -C julia -j $(nproc) FORCE_ASSERTIONS=1 LLVM_ASSERTIONS=1 JULIA_PRECOMPILE=0 echo $PWD/julia/usr/bin >> $GITHUB_PATH - - uses: actions/cache@v1 + - uses: actions/cache@v2 env: cache-name: cache-artifacts with: @@ -120,10 +120,12 @@ jobs: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - uses: julia-actions/julia-processcoverage@v1 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' - - uses: codecov/codecov-action@v1 + - uses: codecov/codecov-action@v4 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: file: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false # or true if you want CI to fail when Codecov fails enzymetestutils: name: EnzymeTestUtils - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libEnzyme }} 
libEnzyme - ${{ github.event_name }} runs-on: ${{ matrix.os }} @@ -143,12 +145,12 @@ jobs: - x64 libEnzyme: [packaged] steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: actions/cache@v2 env: cache-name: cache-artifacts with: @@ -180,10 +182,12 @@ jobs: if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: directories: lib/EnzymeTestUtils/src - - uses: codecov/codecov-action@v2 + - uses: codecov/codecov-action@v4 if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false # or true if you want CI to fail when Codecov fails integration: name: Integration Tests - ${{ matrix.test }} runs-on: ${{ matrix.os }} @@ -200,10 +204,10 @@ jobs: - DynamicExpressions steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" run: | @@ -214,11 +218,11 @@ jobs: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: '1' - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - run: | julia --project=docs -e ' using Pkg From e50e8ad4663eca6cf11f6b4fbfdb28faf58803df Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 17 Oct 2024 00:57:19 -0500 Subject: [PATCH 367/495] Update Project.toml (#1978) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d46d9dd9df..e7b7bb9ae0 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ 
BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.153" +Enzyme_jll = "0.0.154" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 6987986820a651b4e1977cc73113ce532deac7ed Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 17 Oct 2024 15:30:39 -0500 Subject: [PATCH 368/495] More Julia 1.11 (#1981) * More Julia 1.11 * nmi * nmi --- src/compiler.jl | 14 +++++++------- src/jlrt.jl | 6 ++++-- src/rules/customrules.jl | 8 ++++---- src/utils.jl | 28 ++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8f8a01f713..d5c87fcaad 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -680,7 +680,7 @@ end inactivety = if typeof(world) === Nothing EnzymeCore.EnzymeRules.inactive_type(T) else - inmi = GPUCompiler.methodinstance( + inmi = my_methodinstance( typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world, @@ -1135,7 +1135,7 @@ end include("make_zero.jl") function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world) - funcspec = GPUCompiler.methodinstance(typeof(f), tt, world) + funcspec = my_methodinstance(typeof(f), tt, world) nested_codegen!(mode, mod, funcspec, world) end @@ -2984,9 +2984,9 @@ Create the methodinstance pair, and lookup the primal return type. 
primal_tt = Tuple{map(eltype, _tt)...} primal = if world isa Nothing - GPUCompiler.methodinstance(F, primal_tt) + my_methodinstance(F, primal_tt) else - GPUCompiler.methodinstance(F, primal_tt, world) + my_methodinstance(F, primal_tt, world) end return primal @@ -4544,7 +4544,7 @@ function create_abi_wrapper( push!(realparms, val) elseif T <: BatchDuplicatedFunc Func = get_func(T) - funcspec = GPUCompiler.methodinstance(Func, Tuple{}, world) + funcspec = my_methodinstance(Func, Tuple{}, world) llvmf = nested_codegen!(Mode, mod, funcspec, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) Func_RT = Core.Compiler.typeinf_ext_toplevel(interp, funcspec).rettype @@ -7373,7 +7373,7 @@ function GPUCompiler.codegen( ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) fname = String(name) * pf if haskey(functions(mod), fname) - funcspec = GPUCompiler.methodinstance(fnty, Tuple{JT}, world) + funcspec = my_methodinstance(fnty, Tuple{JT}, world) llvmf = nested_codegen!(mode, mod, funcspec, world) push!(function_attributes(llvmf), StringAttribute("implements", fname)) end @@ -8662,7 +8662,7 @@ include("compiler/reflection.jl") target = Compiler.DefaultCompilerTarget() params = Compiler.PrimalCompilerParams(API.DEM_ForwardMode) - mi = GPUCompiler.methodinstance(fn, Tuple{T, Int}) + mi = my_methodinstance(fn, Tuple{T, Int}) job = CompilerJob(mi, CompilerConfig(target, params; kernel = false)) mod, meta = GPUCompiler.codegen( :llvm, diff --git a/src/jlrt.jl b/src/jlrt.jl index 47c5ce3d65..ca2491cece 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -803,10 +803,12 @@ function emit_layout_of_type!(B, ty) @assert !isa(ty, LLVM.Constant) dt = get_datatype_struct() lty = bitcast!(B, ty, LLVM.PointerType(dt, addrspace(value_type(ty)))) - layoutp = inbounds_gep!(B, dt, ty, + layoutp = inbounds_gep!(B, dt, lty, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))], ) - layout = load!(B, lptr, layoutp) + jlvaluet = 
LLVM.PointerType(LLVM.StructType(LLVMType[]), 10) + layout = load!(B, jlvaluet, layoutp) + layout = bitcast!(B, layout, lptr) return layout end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index cb6c60d98d..fbd646866b 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -816,11 +816,11 @@ end augprimal_TT = Tuple{augprimal_tt...} kwfunc = Core.kwfunc(EnzymeRules.augmented_primal) try - ami = GPUCompiler.methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) + ami = my_methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) @safe_debug "Applying custom augmented_primal rule (kwcall)" TT = augprimal_TT catch e augprimal_TT = Tuple{typeof(world),typeof(kwfunc),augprimal_TT.parameters...} - ami = GPUCompiler.methodinstance( + ami = my_methodinstance( typeof(custom_rule_method_error), augprimal_TT, world, @@ -836,7 +836,7 @@ end augprimal_TT = Tuple{augprimal_tt...} try - ami = GPUCompiler.methodinstance( + ami = my_methodinstance( Core.Typeof(EnzymeRules.augmented_primal), augprimal_TT, world, @@ -848,7 +848,7 @@ end typeof(EnzymeRules.augmented_primal), augprimal_TT.parameters..., } - ami = GPUCompiler.methodinstance( + ami = my_methodinstance( typeof(custom_rule_method_error), augprimal_TT, world, diff --git a/src/utils.jl b/src/utils.jl index cad25cf277..167e589713 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -249,3 +249,31 @@ end end export codegen_world_age + + +if VERSION >= v"1.11.0-DEV.1552" + +# XXX: version of Base.method_instance that uses a function type +@inline function my_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type), + world::Integer=tls_world_age()) + sig = GPUCompiler.signature_type_by_tt(ft, tt) + # @assert Base.isdispatchtuple(sig) # JuliaLang/julia#52233 + + mi = ccall(:jl_method_lookup_by_tt, Any, + (Any, Csize_t, Any), + sig, world, #=method_table=# nothing) + mi === nothing && throw(MethodError(ft, tt, world)) + mi = mi::MethodInstance + + # `jl_method_lookup_by_tt` and 
`jl_method_lookup` can return a unspecialized mi + if !Base.isdispatchtuple(mi.specTypes) + mi = Core.Compiler.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance + end + + return mi +end +else + import GPUCompiler: methodinstance as my_methodinstance +end + +export my_methodinstance From da53c038cdf9a2e6dd4536dfd20847cdaeb73fb3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 17 Oct 2024 20:11:12 -0700 Subject: [PATCH 369/495] Continuing 1.11 stuff (#1984) * Continuing 1.11 stuff * cleanup * fix * fix * fix * fixup * fixup * bypass for now * more info and utter confusion * more stringent assertions * correct checks * s * better prints * clean --- src/compiler.jl | 2 +- src/compiler/interpreter.jl | 38 ++++++++ src/compiler/optimize.jl | 76 ++++++++++++++- src/compiler/validation.jl | 178 +++++++++++++++++++++++++++++++++--- src/jlrt.jl | 34 +++---- 5 files changed, 298 insertions(+), 30 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d5c87fcaad..c74dcfc912 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2319,7 +2319,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie world = enzyme_extract_world(fn) has, Ty, byref = abs_typeof(V) if !has - throw(AssertionError("Allocation could not have its type statically determined $(string(V))")) + throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))")) end rt = active_reg_inner(Ty, (), world) if rt == ActiveState || rt == MixedState diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 4d48297ae5..937f61e77d 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -287,6 +287,21 @@ struct AutodiffCallInfo <: CallInfo info::CallInfo end +@static if VERSION < v"1.11.0-" +else + @inline function myunsafe_copyto!(dest::MemoryRef{T}, src::MemoryRef{T}, n) where {T} + Base.@_terminates_globally_notaskstate_meta + @boundscheck memoryref(dest, n), memoryref(src, 
n) + t1 = Base.@_gc_preserve_begin dest + t2 = Base.@_gc_preserve_begin src + Base.memmove(pointer(dest), pointer(src), n * Base.aligned_sizeof(T)) + Base.@_gc_preserve_end t2 + Base.@_gc_preserve_end t1 + return dest + end +end + + function abstract_call_known( interp::EnzymeInterpreter, @nospecialize(f), @@ -322,6 +337,29 @@ function abstract_call_known( end end + @static if VERSION < v"1.11.0-" + else + if f === Base.unsafe_copyto! && length(argtypes) == 4 && + widenconst(argtypes[2]) <: Base.MemoryRef && + widenconst(argtypes[3]) == widenconst(argtypes[2]) && + Base.allocatedinline(eltype(widenconst(argtypes[2]))) && Base.isbitstype(eltype(widenconst(argtypes[2]))) + + arginfo2 = ArgInfo( + fargs isa Nothing ? nothing : + [:(Enzyme.Compiler.Interpreter.myunsafe_copyto!), fargs[2:end]...], + [Core.Const(Enzyme.Compiler.Interpreter.myunsafe_copyto!), argtypes[2:end]...], + ) + return abstract_call_known( + interp, + Enzyme.Compiler.Interpreter.myunsafe_copyto!, + arginfo2, + si, + sv, + max_methods, + ) + end + end + if f === Enzyme.autodiff && length(argtypes) >= 4 if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 7214aa540f..d827a7b9dd 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -783,6 +783,7 @@ function nodecayed_phis!(mod::LLVM.Module) b = IRBuilder() position!(b, terminator(pb)) + v0 = v @inline function getparent(v, offset, hasload) if addr == 11 && addrspace(value_type(v)) == 10 @@ -794,16 +795,87 @@ function nodecayed_phis!(mod::LLVM.Module) if addr == 13 && !hasload if isa(v, LLVM.LoadInst) v2, o2, hl2 = getparent(operands(v)[1], LLVM.ConstantInt(offty, 0), true) - @assert o2 == LLVM.ConstantInt(offty, 0) + rhs = LLVM.ConstantInt(offty, 0) + if o2 != rhs + msg = sprint() do io::IO + println( + io, + "Enzyme internal error addr13 load doesn't keep offset 0", + ) + println(io, "v=", string(v)) + println(io, "v2=", 
string(v2)) + println(io, "o2=", string(o2)) + println(io, "hl2=", string(hl2)) + println(io, "offty=", string(offty)) + println(io, "rhs=", string(rhs)) + end + throw(AssertionError(msg)) + end return v2, offset, true end if isa(v, LLVM.CallInst) cf = LLVM.called_operand(v) if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" ld = operands(v)[2] + while isa(ld, LLVM.BitCastInst) || isa(ld, LLVM.AddrSpaceCastInst) + ld = operands(ld)[1] + end if isa(ld, LLVM.LoadInst) v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true) - @assert o2 == LLVM.ConstantInt(offty, sizeof(Int)) + rhs = LLVM.ConstantInt(offty, sizeof(Int)) + if o2 != rhs + msg = sprint() do io::IO + println( + io, + "Enzyme internal error addr13 load doesn't keep offset 0", + ) + println(io, "mod=", string(LLVM.parent(f))) + println(io, "f=", string(f)) + println(io, "v=", string(v)) + println(io, "opv[1]=", string(operands(v)[1])) + println(io, "opv[2]=", string(operands(v)[2])) + println(io, "ld=", string(ld)) + println(io, "ld_op[1]=", string(operands(ld)[1])) + + println(io, "v2=", string(v2)) + println(io, "o2=", string(o2)) + println(io, "hl2=", string(hl2)) + + println(io, "offty=", string(offty)) + println(io, "rhs=", string(rhs)) + end + throw(AssertionError(msg)) + end + + # We currently only support gc_loaded(mem, ptr) where ptr = (({size_t, {}*}*)mem)->second + # [aka a load of the second element of mem] + base_2, off_2, _ = get_base_and_offset(v2) + base_1, off_1, _ = get_base_and_offset(operands(v)[1]) + if base_1 != base_2 || off_1 != off_2 + msg = sprint() do io::IO + println( + io, + "Enzyme internal error addr13 load data isn't offset of mem", + ) + println(io, "f=", string(f)) + println(io, "v=", string(v)) + println(io, "opv[1]=", string(operands(v)[1])) + println(io, "opv[2]=", string(operands(v)[2])) + println(io, "ld=", string(ld)) + println(io, "ld_op[1]=", string(operands(ld)[1])) + + println(io, "v2=", string(v2)) + println(io, "o2=", string(o2)) 
+ println(io, "hl2=", string(hl2)) + + println(io, "base_1=", string(base_1)) + println(io, "base_2=", string(base_2)) + println(io, "off_1=", string(off_1)) + println(io, "off_2=", string(off_2)) + end + throw(AssertionError(msg)) + end + return v2, offset, true end end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 839aa120d7..4f341ac3f5 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -157,6 +157,24 @@ function restore_lookups(mod::LLVM.Module) eraseInst(mod, f) end end + for f in functions(mod) + for fattr in collect(function_attributes(f)) + if isa(fattr, LLVM.StringAttribute) + if kind(fattr) == "enzymejl_needs_restoration" + v = parse(UInt, LLVM.value(fattr)) + replace_uses!( + f, + LLVM.Value( + LLVM.API.LLVMConstIntToPtr( + ConstantInt(T_size_t, convert(UInt, v)), + value_type(f), + ), + ), + ) + end + end + end + end end function check_ir(job, mod::LLVM.Module) @@ -457,25 +475,163 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) initfn = unwrap_ptr_casts(LLVM.initializer(fn_got)) loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst - opv = operands(loadfn)[1]::LLVM.GlobalVariable - - if startswith(fname, "jl_") || startswith(fname, "ijl_") - else - @assert "unsupported jl got" + opv = operands(loadfn)[1] + if !isa(opv, LLVM.GlobalVariable) msg = sprint() do io::IO println( io, - "Enzyme internal error unsupported got", + "Enzyme internal error unsupported got(load)", ) - println(io, "inst=", inst) - println(io, "fname=", fname) - println(io, "FT=", FT) - println(io, "fn_got=", fn_got) - println(io, "init=", string(initfn)) + println(io, "mod=", string(mod)) + println(io, "initfn=", string(initfn)) + println(io, "loadfn=", string(loadfn)) println(io, "opv=", string(opv)) end throw(AssertionError(msg)) end + opv = opv::LLVM.GlobalVariable + + if startswith(fname, "jl_") || startswith(fname, "ijl_") || startswith(fname, "_j_") + else + found = nothing + for lbb in 
blocks(initfn), linst in collect(instructions(lbb)) + if !isa(linst, LLVM.CallInst) + continue + end + cv = LLVM.called_value(linst) + if !isa(cv, LLVM.Function) + continue + end + if LLVM.name(cv) == "ijl_load_and_lookup" + found = linst + break + end + end + if found == nothing + msg = sprint() do io::IO + println( + io, + "Enzyme internal error unsupported got", + ) + println(io, "inst=", inst) + println(io, "fname=", fname) + println(io, "FT=", FT) + println(io, "fn_got=", fn_got) + println(io, "init=", string(initfn)) + println(io, "opv=", string(opv)) + end + throw(AssertionError(msg)) + end + + legal1, arg1 = abs_cstring(operands(found)[1]) + if legal1 + else + arg1 = operands(found)[1] + + while isa(arg1, ConstantExpr) + if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || + opcode(arg1) == LLVM.API.LLVMBitCast || + opcode(arg1) == LLVM.API.LLVMIntToPtr + arg1 = operands(arg1)[1] + else + break + end + end + if !isa(arg1, LLVM.ConstantInt) + msg = sprint() do io::IO + println( + io, + "Enzyme internal error unsupported got(arg1)", + ) + println(io, "inst=", inst) + println(io, "fname=", fname) + println(io, "FT=", FT) + println(io, "fn_got=", fn_got) + println(io, "init=", string(initfn)) + println(io, "opv=", string(opv)) + println(io, "found=", string(found)) + println(io, "arg1=", string(arg1)) + end + throw(AssertionError(msg)) + end + + arg1 = reinterpret(Ptr{Cvoid}, convert(UInt, arg1)) + end + + legal2, fname = abs_cstring(operands(found)[2]) + if !legal2 + msg = sprint() do io::IO + println( + io, + "Enzyme internal error unsupported got(fname)", + ) + println(io, "inst=", inst) + println(io, "fname=", fname) + println(io, "FT=", FT) + println(io, "fn_got=", fn_got) + println(io, "init=", string(initfn)) + println(io, "opv=", string(opv)) + println(io, "found=", string(found)) + println(io, "fname=", string(operands(found)[2])) + end + throw(AssertionError(msg)) + end + + hnd = operands(found)[3] + + if !isa(hnd, LLVM.GlobalVariable) + msg = sprint() do 
io::IO + println( + io, + "Enzyme internal error unsupported got(hnd)", + ) + println(io, "inst=", inst) + println(io, "fname=", fname) + println(io, "FT=", FT) + println(io, "fn_got=", fn_got) + println(io, "init=", string(initfn)) + println(io, "opv=", string(opv)) + println(io, "found=", string(found)) + println(io, "hnd=", string(hnd)) + end + throw(AssertionError(msg)) + end + hnd = LLVM.name(hnd) + # println(string(mod)) + + # TODO we don't restore/lookup now because this fails + # @vchuravy / @gbaraldi this needs help looking at how to get the actual handle and setup + + if true + res = nothing + elseif arg1 isa AbstractString + res = ccall( + :ijl_load_and_lookup, + Ptr{Cvoid}, + (Cstring, Cstring, Ptr{Cvoid}), + arg1, + fname, + reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + ) + else + res = ccall( + :ijl_load_and_lookup, + Ptr{Cvoid}, + (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), + arg1, + fname, + reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + ) + end + + if res !== nothing + push!(function_attributes(newf), StringAttribute("enzymejl_needs_restoration", string(convert(UInt, res)))) + end + # TODO we can make this relocatable if desired by having restore lookups re-create this got initializer/etc + # metadata(newf)["enzymejl_flib"] = flib + # metadata(newf)["enzymejl_flib"] = flib + + end if value_type(newf) != value_type(inst) newf = const_pointercast(newf, value_type(inst)) diff --git a/src/jlrt.jl b/src/jlrt.jl index ca2491cece..5a8cf33e0c 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -1,15 +1,5 @@ # For julia runtime function emission - -declare_allocobj!(mod::LLVM.Module) = - get_function!(mod, "julia.gc_alloc_obj") do - T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) - T_size_t = convert(LLVM.LLVMType, Int) - - - LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) - end + function emit_allocobj!( 
B::LLVM.IRBuilder, @nospecialize(tag::LLVM.Value), @@ -24,6 +14,7 @@ function emit_allocobj!( T_jlvalue = LLVM.StructType(LLVMType[]) T_pjlvalue = LLVM.PointerType(T_jlvalue) T_ppjlvalue = LLVM.PointerType(T_pjlvalue) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_int8 = LLVM.Int8Type() T_pint8 = LLVM.PointerType(T_int8) @@ -35,12 +26,16 @@ function emit_allocobj!( bitcast!(B, pgcstack, T_ppjlvalue), [LLVM.ConstantInt(current_task_offset())], ) - ptls_field = inbounds_gep!(B, T_pjlvalue, ct, [LLVM.ConstantInt(current_ptls_offset())]) - T_ppint8 = LLVM.PointerType(T_pint8) - ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) + + @static if VERSION < v"1.11.0-" + ptls_field = inbounds_gep!(B, T_pjlvalue, ct, [LLVM.ConstantInt(current_ptls_offset())]) + T_ppint8 = LLVM.PointerType(T_pint8) + ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) + else + ct = bitcast!(B, ct, T_pjlvalue) + end if needs_workaround - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_size_t = convert(LLVM.LLVMType, Int) # This doesn't allow for optimizations alty = LLVM.FunctionType(T_prjlvalue, [T_pint8, T_size_t, T_prjlvalue]) @@ -51,8 +46,15 @@ function emit_allocobj!( return call!(B, alty, alloc_obj, [ptls, Size, tag]) end + T_size_t = convert(LLVM.LLVMType, Int) + + @static if VERSION < v"1.11.0-" + alty = LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) + else + alty = LLVM.FunctionType(T_prjlvalue, [T_pjlvalue, T_size_t, T_prjlvalue]) + end - alloc_obj, alty = declare_allocobj!(mod) + alloc_obj, _ = get_function!(mod, "julia.gc_alloc_obj", alty) return call!(B, alty, alloc_obj, [ct, Size, tag], name) end From 9e945a5936f75eaab891361948148342c8f8772d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 19 Oct 2024 12:07:55 -0700 Subject: [PATCH 370/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index e7b7bb9ae0..480409a3f0 100644 --- 
a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.11" +version = "0.13.12" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.154" +Enzyme_jll = "0.0.155" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 72763e9aa28978ac820c286c01d4bfd00aa451a3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 20 Oct 2024 15:52:00 -0700 Subject: [PATCH 371/495] 1.11: the adventure continues, destroy (#1986) * 1.11: the adventure continues, destroy * fix * fixup * fix * cleanup * fix * fix * fix * fix * fix * fix * fix * fix --- src/absint.jl | 50 +++++++---- src/compiler.jl | 16 +++- src/compiler/optimize.jl | 73 ++-------------- src/rules/llvmrules.jl | 176 ++++++++++++--------------------------- src/rules/typerules.jl | 40 +++++++++ src/typetree.jl | 40 ++++++++- src/utils.jl | 26 ++++++ test/typetree.jl | 5 ++ 8 files changed, 219 insertions(+), 207 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 4ed9d6c91b..dba99d2b00 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -42,6 +42,9 @@ function absint(arg::LLVM.Value, partial::Bool = false) if nm == "julia.pointer_from_objref" return absint(operands(arg)[1], partial) end + if nm == "julia.gc_loaded" + return absint(operands(arg)[2], partial) + end if nm == "jl_typeof" || nm == "ijl_typeof" vals = abs_typeof(operands(arg)[1], partial) return (vals[1], vals[2]) @@ -158,7 +161,13 @@ function absint(arg::LLVM.Value, partial::Bool = false) end function actual_size(@nospecialize(typ2)) - if typ2 <: Array || typ2 <: AbstractString || typ2 <: Symbol + @static if VERSION < v"1.11-" + if typ2 <: Array + return sizeof(Int) + end + else + end + if typ2 <: AbstractString || typ2 <: Symbol return 
sizeof(Int) elseif Base.isconcretetype(typ2) return sizeof(typ2) @@ -256,6 +265,11 @@ function abs_typeof( return abs_typeof(operands(arg)[1], partial) end + if nm == "julia.gc_loaded" + legal, res, byref = abs_typeof(operands(arg)[2], partial) + return legal, res, byref + end + for (fname, ty) in ( ("jl_box_int64", Int64), ("ijl_box_int64", Int64), @@ -453,7 +467,7 @@ function abs_typeof( fo = fieldoffset(typ, i) if fieldoffset(typ, i) == offset offset = 0 - typ = fieldtype(typ, i) + typ = typed_fieldtype(typ, i) if !Base.allocatedinline(typ) if byref != GPUCompiler.BITS_VALUE legal = false @@ -464,7 +478,7 @@ function abs_typeof( break elseif fieldoffset(typ, i) > offset offset = offset - fieldoffset(typ, lasti) - typ = fieldtype(typ, lasti) + typ = typed_fieldtype(typ, lasti) @assert Base.isconcretetype(typ) if !Base.allocatedinline(typ) legal = false @@ -477,15 +491,15 @@ function abs_typeof( lasti = i end end - if !seen && fieldcount(typ) > 0 - offset = offset - fieldoffset(typ, lasti) - typ = fieldtype(typ, lasti) - @assert Base.isconcretetype(typ) - if !Base.allocatedinline(typ) - legal = false - end - seen = true - end + if !seen && fieldcount(typ) > 0 + offset = offset - fieldoffset(typ, lasti) + typ = typed_fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) + if !Base.allocatedinline(typ) + legal = false + end + seen = true + end if !seen legal = false end @@ -495,8 +509,14 @@ function abs_typeof( while legal && should_recurse(typ2, value_type(arg), byref, dl) idx, _ = first_non_ghost(typ2) if idx != -1 - typ2 = fieldtype(typ2, idx) - if !Base.allocatedinline(typ2) + typ2 = typed_fieldtype(typ2, idx) + if Base.allocatedinline(typ2) + if byref == GPUCompiler.BITS_VALUE + continue + end + legal = false + break + else if byref != GPUCompiler.BITS_VALUE legal = false break @@ -532,7 +552,7 @@ function abs_typeof( @assert Base.isconcretetype(typ) cnt = 0 for i = 1:fieldcount(typ) - styp = fieldtype(typ, i) + styp = typed_fieldtype(typ, i) if 
isghostty(styp) continue end diff --git a/src/compiler.jl b/src/compiler.jl index c74dcfc912..cf879000b5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -467,7 +467,7 @@ end return Val(AnyState) end - subT = fieldtype(T, f) + subT = typed_fieldtype(T, f) if justActive && !allocatedinline(subT) return Val(AnyState) @@ -2441,7 +2441,7 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) if isa(ty, LLVM.StructType) i = 1 for ii = 1:fieldcount(jlty) - jlet = fieldtype(jlty, ii) + jlet = typed_fieldtype(jlty, ii) if isghostty(jlet) || Core.Compiler.isconstType(jlet) continue end @@ -3816,6 +3816,18 @@ function enzyme!( LLVM.API.LLVMValueRef, ) ), + "julia.gc_loaded" => @cfunction( + inoutgcloaded_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), "julia.pointer_from_objref" => @cfunction( inout_rule, UInt8, diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index d827a7b9dd..86760b423f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -823,60 +823,18 @@ function nodecayed_phis!(mod::LLVM.Module) if isa(ld, LLVM.LoadInst) v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true) rhs = LLVM.ConstantInt(offty, sizeof(Int)) - if o2 != rhs - msg = sprint() do io::IO - println( - io, - "Enzyme internal error addr13 load doesn't keep offset 0", - ) - println(io, "mod=", string(LLVM.parent(f))) - println(io, "f=", string(f)) - println(io, "v=", string(v)) - println(io, "opv[1]=", string(operands(v)[1])) - println(io, "opv[2]=", string(operands(v)[2])) - println(io, "ld=", string(ld)) - println(io, "ld_op[1]=", string(operands(ld)[1])) - - println(io, "v2=", string(v2)) - println(io, "o2=", string(o2)) - println(io, "hl2=", string(hl2)) - - println(io, "offty=", string(offty)) - println(io, "rhs=", string(rhs)) - end - throw(AssertionError(msg)) - end - # We currently only support gc_loaded(mem, ptr) where 
ptr = (({size_t, {}*}*)mem)->second - # [aka a load of the second element of mem] base_2, off_2, _ = get_base_and_offset(v2) base_1, off_1, _ = get_base_and_offset(operands(v)[1]) - if base_1 != base_2 || off_1 != off_2 - msg = sprint() do io::IO - println( - io, - "Enzyme internal error addr13 load data isn't offset of mem", - ) - println(io, "f=", string(f)) - println(io, "v=", string(v)) - println(io, "opv[1]=", string(operands(v)[1])) - println(io, "opv[2]=", string(operands(v)[2])) - println(io, "ld=", string(ld)) - println(io, "ld_op[1]=", string(operands(ld)[1])) - - println(io, "v2=", string(v2)) - println(io, "o2=", string(o2)) - println(io, "hl2=", string(hl2)) - - println(io, "base_1=", string(base_1)) - println(io, "base_2=", string(base_2)) - println(io, "off_1=", string(off_1)) - println(io, "off_2=", string(off_2)) - end - throw(AssertionError(msg)) + + if o2 == rhs && base_1 == base_2 && off_1 == off_2 + return v2, offset, true end - return v2, offset, true + rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty) + lhs = ptrtoint!(b, operands(v)[2], offty) + off2 = nuwsub!(b, rhs, lhs) + return v2, nuwadd!(b, offset, off2), true end end end @@ -1127,24 +1085,11 @@ function nodecayed_phis!(mod::LLVM.Module) else base_obj = nphi - # %value_phi11 = phi {} addrspace(10)* [ %55, %L78 ], [ %54, %L76 ] - - # %.phi.trans.insert77 = bitcast {} addrspace(10)* %value_phi11 to { i64, {} addrspace(10)** } addrspace(10)* - # %.phi.trans.insert78 = addrspacecast { i64, {} addrspace(10)** } addrspace(10)* %.phi.trans.insert77 to { i64, {} addrspace(10)** } addrspace(11)* - # %.phi.trans.insert79 = getelementptr inbounds { i64, {} addrspace(10)** }, { i64, {} addrspace(10)** } addrspace(11)* %.phi.trans.insert78, i64 0, i32 1 - # %.pre80 = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %.phi.trans.insert79, align 8, !dbg !532, !tbaa !19, !alias.scope !26, !noalias !29 - - # %154 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} 
addrspace(10)* %value_phi11, {} addrspace(10)** %.pre80), !dbg !532 - jlt = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10) pjlt = LLVM.PointerType(jlt) - gent = LLVM.StructType([convert(LLVMType, Int), pjlt]) - pgent = LLVM.PointerType(LLVM.StructType([convert(LLVMType, Int), pjlt]), 10) - nphi = bitcast!(nb, nphi, pgent) - nphi = addrspacecast!(nb, nphi, LLVM.PointerType(gent, 11)) - nphi = inbounds_gep!(nb, gent, nphi, [LLVM.ConstantInt(Int64(0)), LLVM.ConstantInt(Int32(1))]) - nphi = load!(nb, pjlt, nphi) + nphi = get_memory_data(nb, nphi) + nphi = bitcast!(nb, nphi, pjlt) GTy = LLVM.FunctionType(LLVM.PointerType(jlt, 13), LLVM.LLVMType[jlt, pjlt]) gcloaded, _ = get_function!( diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 899c9f6d43..edb2bcd6e6 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -447,21 +447,27 @@ end i8 = LLVM.IntType(8) algn = 0 - if width == 1 - shadowres = call_samefunc_with_inverted_bundles!( + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + ev = if width == 1 + shadowin + else + extract_value!(B, shadowin, idx - 1) + end + + callv = call_samefunc_with_inverted_bundles!( B, gutils, orig, - [shadowin], + [ev], [API.VT_Shadow], false, ) #=lookup=# - - # TODO zero based off runtime types, rather than presume floatlike? 
if is_constant_value(gutils, origops[1]) - elSize = get_array_elsz(B, shadowin) + elSize = get_array_elsz(B, ev) elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) - len = get_array_len(B, shadowin) + len = get_array_len(B, ev) length = LLVM.mul!(B, len, elSize) bt = GPUCompiler.backtrace(orig) btstr = sprint() do io @@ -471,7 +477,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!( B, - get_array_data(B, shadowres), + get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn, @@ -479,69 +485,25 @@ end end if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) - shadowres = LLVM.select!( + callv = LLVM.select!( B, LLVM.icmp!( B, LLVM.API.LLVMIntNE, - shadowin, + ev, new_from_original(gutils, origops[1]), ), - shadowres, + callv, prev, ) - API.moveBefore(prev, shadowres, B) - end - else - shadowres = - UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx = 1:width - ev = extract_value!(B, shadowin, idx - 1) - callv = call_samefunc_with_inverted_bundles!( - B, - gutils, - orig, - [ev], - [API.VT_Shadow], - false, - ) #=lookup=# - if is_constant_value(gutils, origops[1]) - elSize = get_array_elsz(B, ev) - elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) - len = get_array_len(B, ev) - length = LLVM.mul!(B, len, elSize) - bt = GPUCompiler.backtrace(orig) - btstr = sprint() do io - print(io, "\nCaused by:") - Base.show_backtrace(io, bt) - end - GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" - LLVM.memset!( - B, - get_array_data(B, callv), - LLVM.ConstantInt(i8, 0, false), - length, - algn, - ) - end - if get_runtime_activity(gutils) - prev = new_from_original(gutils, orig) - callv = LLVM.select!( - B, - LLVM.icmp!( - B, - LLVM.API.LLVMIntNE, - ev, - new_from_original(gutils, origops[1]), - ), - callv, - prev, - ) - if idx == 1 - API.moveBefore(prev, 
callv, B) - end + if idx == 1 + API.moveBefore(prev, callv, B) end - shadowres = insert_value!(B, shadowres, callv, idx - 1) + end + shadowres = if width == 1 + callv + else + insert_value!(B, shadowres, callv, idx - 1) end end @@ -666,7 +628,7 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin if memory if fwd - shadowsrc = memoryptr + shadowsrc = inttoptr!(B, memoryptr, LLVM.PointerType(LLVM.IntType(8))) lookup_src = false else shadowsrc = invert_pointer(gutils, shadowsrc, B) @@ -797,19 +759,29 @@ end i8 = LLVM.IntType(8) algn = 0 - if width == 1 - shadowres = call_samefunc_with_inverted_bundles!( + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + ev = if width == 1 + shadowin + else + extract_value!(B, shadowin, idx - 1) + end + ev2 = if width == 1 + shadowdata + else + extract_value!(B, shadowdata, idx - 1) + end + callv = call_samefunc_with_inverted_bundles!( B, gutils, orig, - [shadowin, shadowdata, len], + [ev, ev2, len], [API.VT_Shadow, API.VT_Shadow, API.VT_Primal], false, ) #=lookup=# - - # TODO zero based off runtime types, rather than presume floatlike? 
if is_constant_value(gutils, origops[1]) - elSize = get_memory_elsz(B, shadowin) + elSize = get_array_elsz(B, ev) elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) length = LLVM.mul!(B, len, elSize) bt = GPUCompiler.backtrace(orig) @@ -820,7 +792,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr" LLVM.memset!( B, - shadowdata, + ev2, LLVM.ConstantInt(i8, 0, false), length, algn, @@ -828,69 +800,25 @@ end end if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) - shadowres = LLVM.select!( + callv = LLVM.select!( B, LLVM.icmp!( B, LLVM.API.LLVMIntNE, - shadowin, + ev, new_from_original(gutils, origops[1]), ), - shadowres, + callv, prev, ) - API.moveBefore(prev, shadowres, B) - end - else - shadowres = - UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx = 1:width - ev = extract_value!(B, shadowin, idx - 1) - ev2 = extract_value!(B, shadowdata, idx - 1) - callv = call_samefunc_with_inverted_bundles!( - B, - gutils, - orig, - [ev, ev2, len], - [API.VT_Shadow, API.VT_Shadow, API.VT_Primal], - false, - ) #=lookup=# - if is_constant_value(gutils, origops[1]) - elSize = get_array_elsz(B, ev) - elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) - length = LLVM.mul!(B, len, elSize) - bt = GPUCompiler.backtrace(orig) - btstr = sprint() do io - print(io, "\nCaused by:") - Base.show_backtrace(io, bt) - end - GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr" - LLVM.memset!( - B, - ev2, - LLVM.ConstantInt(i8, 0, false), - length, - algn, - ) + if idx == 1 + API.moveBefore(prev, callv, B) end - if get_runtime_activity(gutils) - prev = new_from_original(gutils, orig) - callv = LLVM.select!( - B, - LLVM.icmp!( - B, - LLVM.API.LLVMIntNE, - ev, - new_from_original(gutils, origops[1]), - ), - callv, - prev, - ) - if idx == 1 - API.moveBefore(prev, callv, B) - end - end - shadowres = 
insert_value!(B, shadowres, callv, idx - 1) + end + shadowres = if width == 1 + callv + else + insert_value!(B, shadowres, callv, idx - 1) end end diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 2cba33c14e..de11d3c1cd 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -92,3 +92,43 @@ function inoutcopyslice_rule( end return UInt8(false) end + +function inoutgcloaded_rule( + direction::Cint, + ret::API.CTypeTreeRef, + args::Ptr{API.CTypeTreeRef}, + known_values::Ptr{API.IntList}, + numArgs::Csize_t, + val::LLVM.API.LLVMValueRef, +)::UInt8 + if numArgs != 1 + return UInt8(false) + end + inst = LLVM.Instruction(val) + + legal, typ = abs_typeof(inst) + + if legal + if (direction & API.DOWN) != 0 + ctx = LLVM.context(inst) + dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) + if GPUCompiler.deserves_retbox(typ) + typ = Ptr{typ} + end + rest = typetree(typ, ctx, dl) + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) + @assert legal + end + return UInt8(false) + end + + if (direction & API.UP) != 0 + changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret) + @assert legal + end + if (direction & API.DOWN) != 0 + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2)) + @assert legal + end + return UInt8(false) +end \ No newline at end of file diff --git a/src/typetree.jl b/src/typetree.jl index 61d700acb8..8224b98952 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -123,7 +123,7 @@ function get_offsets(@nospecialize(T::Type)) results = Tuple{API.CConcreteType,Int}[] for f = 1:fieldcount(T) offset = fieldoffset(T, f) - subT = fieldtype(T, f) + subT = typed_fieldtype(T, f) if !allocatedinline(subT) || subT isa UnionAll || subT isa Union || subT == Union{} push!(results, (API.DT_Pointer, offset)) @@ -305,6 +305,42 @@ else end return tt end + + function typetree_inner( + AT::Type{<:GenericMemoryRef{kind,T}}, + ctx, + dl, + seen::TypeTreeTable, + ) where {kind,T} + 
offset = 0 + tt = copy(typetree(T, ctx, dl, seen)) + if !allocatedinline(T) && Base.isconcretetype(T) + Enzyme.merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, 0) + end + Enzyme.merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, 0) + + for f = 2:fieldcount(AT) + offset = fieldoffset(AT, f) + subT = typed_fieldtype(AT, f) + + subtree = copy(typetree(subT, ctx, dl, seen)) + + # Allocated inline so adjust first path + if allocatedinline(subT) + shift!(subtree, dl, 0, sizeof(subT), offset) + else + Enzyme.merge!(subtree, TypeTree(API.DT_Pointer, ctx)) + only!(subtree, offset) + end + + Enzyme.merge!(tt, subtree) + end + canonicalize!(tt, sizeof(AT), dl) + + return tt + end end import Base: ismutabletype @@ -352,7 +388,7 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) tt = TypeTree() for f = 1:fieldcount(T) offset = fieldoffset(T, f) - subT = fieldtype(T, f) + subT = typed_fieldtype(T, f) if subT isa UnionAll || subT isa Union || subT == Union{} if !allocatedinline(subT) diff --git a/src/utils.jl b/src/utils.jl index 167e589713..cd28ae9a20 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -277,3 +277,29 @@ else end export my_methodinstance + + +@static if VERSION < v"1.11-" + +@inline function typed_fieldtype(@nospecialize(T::Type), i::Int) + fieldtype(T, i) +end + +else + +@inline function typed_fieldtype(@nospecialize(T::Type), i::Int) + if T <: GenericMemoryRef && i == 1 || T <: GenericMemory && i == 2 + eT = eltype(T) + if !allocatedinline(eT) && Base.isconcretetype(eT) + Ptr{Ptr{eT}} + else + Ptr{eT} + end + else + fieldtype(T, i) + end +end + +end + +export typed_fieldtype diff --git a/test/typetree.jl b/test/typetree.jl index 3b47161f62..074103ea7c 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -74,6 +74,11 @@ end "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" @test 
tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == "{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,0]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,0]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + + @static if VERSION < v"1.11-" + else + @test tt(MemoryRef{Float32}) == "{[-1]:Pointer, [0,-1]:Float@float, [8,0]:Integer, [8,1]:Integer, [8,2]:Integer, [8,3]:Integer, [8,4]:Integer, [8,5]:Integer, [8,6]:Integer, [8,7]:Integer, [8,8]:Pointer, [8,8,-1]:Float@float}" + end else @test tt(UnionMember) == "{[0]:Float@float, [4]:Pointer, [8]:Integer}" @test tt(LList2{Float64}) == "{[0]:Pointer, [4]:Float@double}" From 924a2716f069142873cf44f8b3c1f472b5c7f82b Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 21 Oct 2024 10:56:43 -0700 Subject: [PATCH 372/495] 1.11: more methodinstance stuff (#1989) * 1.11: more methodinstance stuff * fixup * fix * fix elsz issue * fix * fix * fix --- Project.toml | 2 +- src/compiler.jl | 24 -------------- src/rules/llvmrules.jl | 19 +++++------ src/rules/typerules.jl | 40 ---------------------- src/utils.jl | 75 ++++++++++++++++++++++++++++++++++-------- test/internal_rules.jl | 4 +-- 6 files changed, 74 insertions(+), 90 deletions(-) diff --git a/Project.toml b/Project.toml index 480409a3f0..b1c9a8ed17 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.155" +Enzyme_jll = "0.0.156" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/src/compiler.jl b/src/compiler.jl index cf879000b5..36ba2c8656 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3816,30 +3816,6 @@ function enzyme!( LLVM.API.LLVMValueRef, ) ), - "julia.gc_loaded" => @cfunction( - 
inoutgcloaded_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "julia.pointer_from_objref" => @cfunction( - inout_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), "jl_inactive_inout" => @cfunction( inout_rule, UInt8, diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index edb2bcd6e6..667d6f8abb 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -628,13 +628,11 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin if memory if fwd - shadowsrc = inttoptr!(B, memoryptr, LLVM.PointerType(LLVM.IntType(8))) lookup_src = false + shadowsrc = invert_pointer(gutils, memoryptr, B) else - shadowsrc = invert_pointer(gutils, shadowsrc, B) - if !fwd - shadowsrc = lookup_value(gutils, shadowsrc, B) - end + shadowsrc = invert_pointer(gutils, shadowsrc, B) + shadowsrc = lookup_value(gutils, shadowsrc, B) end else shadowsrc = invert_pointer(gutils, shadowsrc, B) @@ -674,12 +672,13 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin # src already has done the lookup from the argument shadowsrc0 = if lookup_src if memory + # TODO this may not be at the same offset as the start of the copy, e.g. 
get_memory_data(src) != memoryptr get_memory_data(B, evsrc) else get_array_data(B, evsrc) end else - evsrc + inttoptr!(B, evsrc, LLVM.PointerType(LLVM.IntType(8))) end shadowdst0 = if memory @@ -781,7 +780,7 @@ end false, ) #=lookup=# if is_constant_value(gutils, origops[1]) - elSize = get_array_elsz(B, ev) + elSize = get_memory_elsz(B, ev) elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) length = LLVM.mul!(B, len, elSize) bt = GPUCompiler.backtrace(orig) @@ -792,7 +791,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr" LLVM.memset!( B, - ev2, + inttoptr!(B, ev2, LLVM.PointerType(LLVM.IntType(8))), LLVM.ConstantInt(i8, 0, false), length, algn, @@ -838,7 +837,7 @@ end shadowres = LLVM.Value(unsafe_load(shadowR)) len = new_from_original(gutils, origops[3]) - memoryptr = new_from_original(gutils, origops[2]) + memoryptr = origops[2] arraycopy_common(true, B, orig, origops[1], gutils, shadowres; len, memoryptr) end @@ -849,7 +848,7 @@ end origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) len = new_from_original(gutils, origops[3]) - memoryptr = new_from_original(gutils, origops[2]) + memoryptr = origops[2] arraycopy_common(false, B, orig, origops[1], gutils, nothing; len, memoryptr) end diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index de11d3c1cd..2cba33c14e 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -92,43 +92,3 @@ function inoutcopyslice_rule( end return UInt8(false) end - -function inoutgcloaded_rule( - direction::Cint, - ret::API.CTypeTreeRef, - args::Ptr{API.CTypeTreeRef}, - known_values::Ptr{API.IntList}, - numArgs::Csize_t, - val::LLVM.API.LLVMValueRef, -)::UInt8 - if numArgs != 1 - return UInt8(false) - end - inst = LLVM.Instruction(val) - - legal, typ = abs_typeof(inst) - - if legal - if (direction & API.DOWN) != 0 - ctx = LLVM.context(inst) - dl = 
string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - if GPUCompiler.deserves_retbox(typ) - typ = Ptr{typ} - end - rest = typetree(typ, ctx, dl) - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) - @assert legal - end - return UInt8(false) - end - - if (direction & API.UP) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret) - @assert legal - end - if (direction & API.DOWN) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2)) - @assert legal - end - return UInt8(false) -end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index cd28ae9a20..0b441b1b25 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,24 +253,73 @@ export codegen_world_age if VERSION >= v"1.11.0-DEV.1552" + +const prevmethodinstance = GPUCompiler.generic_methodinstance + +function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type) + @nospecialize + @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) + ft = ft.parameters[1] + tt = tt.parameters[1] + + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec()) + + # look up the method match + method_error = :(throw(MethodError(ft, tt, $world))) + sig = Tuple{ft, tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + match = ccall(:jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, #=mt=# nothing, world, min_world, max_world) + match === nothing && return stub(world, source, method_error) + + # look up the method and code instance + mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams) + ci = Core.Compiler.retrieve_code_info(mi, world) + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + empty!(new_ci.codelocs) + empty!(new_ci.linetable) + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + + # propagate 
edge metadata + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + new_ci.edges = MethodInstance[mi] + + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] + new_ci.slotflags = UInt8[0x00 for i = 1:3] + + # return the method instance + push!(new_ci.code, Core.Compiler.ReturnNode(mi)) + push!(new_ci.ssaflags, 0x00) + push!(new_ci.linetable, GPUCompiler.@LineInfoNode(methodinstance)) + push!(new_ci.codelocs, 1) + new_ci.ssavaluetypes += 1 + + return new_ci +end + +@eval function prevmethodinstance(ft, tt) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, methodinstance_generator)) +end + # XXX: version of Base.method_instance that uses a function type @inline function my_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type), world::Integer=tls_world_age()) sig = GPUCompiler.signature_type_by_tt(ft, tt) - # @assert Base.isdispatchtuple(sig) # JuliaLang/julia#52233 - - mi = ccall(:jl_method_lookup_by_tt, Any, - (Any, Csize_t, Any), - sig, world, #=method_table=# nothing) - mi === nothing && throw(MethodError(ft, tt, world)) - mi = mi::MethodInstance - - # `jl_method_lookup_by_tt` and `jl_method_lookup` can return a unspecialized mi - if !Base.isdispatchtuple(mi.specTypes) - mi = Core.Compiler.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance + if Base.isdispatchtuple(sig) # JuliaLang/julia#52233 + return GPUCompiler.methodinstance(ft, tt, world) + else + return prevmethodinstance(ft, tt, world) end - - return mi end else import GPUCompiler: methodinstance as my_methodinstance diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 3635ce07e2..fb5926a1d3 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -2,8 +2,6 @@ module InternalRules using Enzyme using Enzyme.EnzymeRules -using EnzymeTestUtils -using FiniteDifferences using LinearAlgebra using SparseArrays using Test @@ -155,6 +153,7 @@ function tr_solv(A, B, uplo, trans, diag, idx) end +using FiniteDifferences 
@testset "Reverse triangular solve" begin A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542] B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784] @@ -576,6 +575,7 @@ end @test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] end +using EnzymeTestUtils @testset "Linear solve for triangular matrices" begin @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) From b76585a87d4c5b53c48ab87d2b7cf47ebd58a0e0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 15:48:10 -0700 Subject: [PATCH 373/495] 1.11: more gcloaded work (#1999) * 1.11: more gcloaded work * fix * fix * fix --- src/compiler.jl | 15 ++++++++++++- src/compiler/optimize.jl | 38 +++++++++++++++++++++++---------- test/optimize.jl | 46 ++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 88 insertions(+), 12 deletions(-) create mode 100644 test/optimize.jl diff --git a/src/compiler.jl b/src/compiler.jl index 36ba2c8656..bb7ec835f8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3207,7 +3207,20 @@ function annotate!(mod, mode) ) if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + else + push!(function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << 
getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + ) + end for u in LLVM.uses(fn) c = LLVM.user(u) if !isa(c, LLVM.CallInst) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 86760b423f..a35de5608f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -828,13 +828,14 @@ function nodecayed_phis!(mod::LLVM.Module) base_1, off_1, _ = get_base_and_offset(operands(v)[1]) if o2 == rhs && base_1 == base_2 && off_1 == off_2 - return v2, offset, true + return operands(v)[1], offset, true end rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty) lhs = ptrtoint!(b, operands(v)[2], offty) - off2 = nuwsub!(b, rhs, lhs) - return v2, nuwadd!(b, offset, off2), true + off2 = nuwsub!(b, lhs, rhs) + add = nuwadd!(b, offset, off2) + return operands(v)[1], add, true end end end @@ -905,8 +906,12 @@ function nodecayed_phis!(mod::LLVM.Module) end if isa(v, LLVM.BitCastInst) + preop = operands(v)[1] + while isa(preop, LLVM.BitCastInst) + preop = operands(preop)[1] + end v2, offset, skipload = - getparent(operands(v)[1], offset, hasload) + getparent(preop, offset, hasload) v2 = bitcast!( b, v2, @@ -1059,7 +1064,7 @@ function nodecayed_phis!(mod::LLVM.Module) end nb = IRBuilder() - position!(nb, inst) + position!(nb, nonphi) offset = goffsets[inst] append!(LLVM.incoming(offset), offsets) @@ -1068,15 +1073,26 @@ function nodecayed_phis!(mod::LLVM.Module) end nphi = nextvs[inst] - if !all(x -> x[1] == nvs[1][1], nvs) - append!(LLVM.incoming(nphi), nvs) - else - replace_uses!(nphi, nvs[1][1]) + + function ogbc(x) + while isa(x, LLVM.BitCastInst) + x = operands(x)[1] + end + return x + end + + if all(x -> ogbc(x[1]) == ogbc(nvs[1][1]), nvs) + bc = ogbc(nvs[1][1]) + if value_type(bc) != value_type(nphi) + bc = bitcast!(nb, bc, value_type(nphi)) + end + replace_uses!(nphi, bc) LLVM.API.LLVMInstructionEraseFromParent(nphi) - nphi = nvs[1][1] + nphi = bc + else + append!(LLVM.incoming(nphi), nvs) end - 
position!(nb, nonphi) if addr == 13 @static if VERSION < v"1.11-" nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) diff --git a/test/optimize.jl b/test/optimize.jl new file mode 100644 index 0000000000..a4fcc1768f --- /dev/null +++ b/test/optimize.jl @@ -0,0 +1,46 @@ +using Enzyme, LinearAlgebra, Test + +function gcloaded_fixup(dest, src) + N = size(src) + dat = src.data + len = N[1] + + i = 1 + while true + j = 1 + while true + ld = @inbounds if i <= j + dat[(i-1) * 2 + j] + else + dat[(j-1) * 2 + i] + end + @inbounds dest[(i-1) * 2 + j] = ld + if j == len + break + end + j += 1 + end + if i == len + break + end + i += 1 + end + return nothing +end + +@testset "GCLoaded fixup" begin + H = Hermitian(Matrix([4.0 1.0; 2.0 5.0])) + dest = Matrix{Float64}(undef, 2, 2) + + Enzyme.autodiff( + ForwardWithPrimal, + gcloaded_fixup, + Const, + Const(dest), + Const(H), + )[1] + @test dest ≈ [4.0 2.0; 2.0 5.0] + dest = Matrix{Float64}(undef, 2, 2) + gcloaded_fixup(dest, H) + @test dest ≈ [4.0 2.0; 2.0 5.0] +end diff --git a/test/runtests.jl b/test/runtests.jl index 8c7ca39abc..b3a64a2a21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,6 +73,7 @@ end include("abi.jl") include("typetree.jl") +include("optimize.jl") include("rules.jl") include("rrules.jl") From db78d0adb7cb590e47ebac70bebf0fd93f596728 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 15:52:06 -0700 Subject: [PATCH 374/495] CI: disable metal/amdgpu for now (#2004) --- .buildkite/pipeline.yml | 104 +++--- test/internal_rules.jl | 750 +--------------------------------------- 2 files changed, 59 insertions(+), 795 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 1a9f70d04c..98b8facf86 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -25,55 +25,55 @@ steps: env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - - label: "AMDGPU Julia v{{matrix.version}}" - matrix: - setup: - version: - - "1.10" - plugins: - - JuliaCI/julia#v1: - 
version: "{{matrix.version}}" - agents: - queue: "juliagpu" - rocm: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - commands: | - echo "--- Setup Julia packages" - julia --color=yes -e ' - using Pkg - pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] - push!(pkgs, PackageSpec(; name="AMDGPU")) - Pkg.develop(pkgs)' || exit 3 - - echo "+++ Run tests" - julia --color=yes test/amdgpu.jl - env: - JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - - - label: "Metal Julia v{{matrix.version}}" - matrix: - setup: - version: - - "1.10" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.version}}" - agents: - queue: "juliaecosystem" - os: "macos" - arch: "aarch64" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - commands: | - echo "--- Setup Julia packages" - julia --color=yes -e ' - using Pkg - pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] - push!(pkgs, PackageSpec(; name="Metal")) - Pkg.develop(pkgs)' || exit 3 - - echo "+++ Run tests" - julia --color=yes test/metal.jl - env: - JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager +# - label: "AMDGPU Julia v{{matrix.version}}" +# matrix: +# setup: +# version: +# - "1.10" +# plugins: +# - JuliaCI/julia#v1: +# version: "{{matrix.version}}" +# agents: +# queue: "juliagpu" +# rocm: "*" +# if: build.message !~ /\[skip tests\]/ +# timeout_in_minutes: 60 +# commands: | +# echo "--- Setup Julia packages" +# julia --color=yes -e ' +# using Pkg +# pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] +# push!(pkgs, PackageSpec(; name="AMDGPU")) +# Pkg.develop(pkgs)' || exit 3 +# +# echo "+++ Run tests" +# julia --color=yes test/amdgpu.jl +# env: +# JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager +# +# - label: "Metal Julia v{{matrix.version}}" +# matrix: +# setup: +# version: +# - "1.10" +# plugins: +# - JuliaCI/julia#v1: +# version: "{{matrix.version}}" +# agents: +# queue: "juliaecosystem" +# 
os: "macos" +# arch: "aarch64" +# if: build.message !~ /\[skip tests\]/ +# timeout_in_minutes: 60 +# commands: | +# echo "--- Setup Julia packages" +# julia --color=yes -e ' +# using Pkg +# pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] +# push!(pkgs, PackageSpec(; name="Metal")) +# Pkg.develop(pkgs)' || exit 3 +# +# echo "+++ Run tests" +# julia --color=yes test/metal.jl +# env: +# JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager diff --git a/test/internal_rules.jl b/test/internal_rules.jl index fb5926a1d3..67ec233982 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -1,4 +1,3 @@ -module InternalRules using Enzyme using Enzyme.EnzymeRules @@ -7,753 +6,18 @@ using SparseArrays using Test import Random -struct TPair - a::Float64 - b::Float64 -end - -function sorterrfn(t, x) - function lt(a, b) - return a.a < b.a - end - return first(sortperm(t, lt=lt)) * x -end - -@testset "Sort rules" begin - function f1(x) - a = [1.0, 3.0, x] - sort!(a) - return a[2] - end - - @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 - @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) - @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 - @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 - @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) - @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 - - function f2(x) - a = [1.0, -3.0, -x, -2x, x] - sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y)) - return sum(a .* [1, 2, 3, 4, 5]) - end - - @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 - @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) - @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 - - function f3(x) - a = [2.0, 2.5, x, 1.0] - return partialsort(a, 2) - end - - @test autodiff(Forward, f3, Duplicated(1.5, 1.0))[1] == 1.0 - 
@test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) - @test autodiff(Reverse, f3, Active(1.5))[1][1] == 1.0 - @test autodiff(Reverse, f3, Active(2.5))[1][1] == 0.0 - - function f4(x) - a = [2.0, 2.5, x, x / 2] - y = partialsort(a, 1:2) - return sum(y) - end - - @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 - @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) - @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 - @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 - @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 - - dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) - res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) - - @test res[1][2] ≈ 3 - @test dd.dval[1].a ≈ 0 - @test dd.dval[1].b ≈ 0 - @test dd.dval[2].a ≈ 0 - @test dd.dval[2].b ≈ 0 - @test dd.dval[3].a ≈ 0 - @test dd.dval[3].b ≈ 0 -end - -@testset "Linear Solve" begin - A = Float64[2 3; 5 7] - dA = zero(A) - b = Float64[11, 13] - db = zero(b) - - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) - - tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) - - dy = Float64[17, 19] - copyto!(shadow, dy) - - pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape) - - z = transpose(A) \ dy - - y = A \ b - @test dA ≈ (-z * transpose(y)) - @test db ≈ z - - db = zero(b) - - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) - - tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) - - dy = Float64[17, 19] - copyto!(shadow, dy) - - pullback(Const(\), Const(A), Duplicated(b, db), tape) - - z = transpose(A) \ dy - - y = A \ b - @test db ≈ z - - dA = zero(A) - - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, 
Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) - - tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) - - dy = Float64[17, 19] - copyto!(shadow, dy) - - pullback(Const(\), Duplicated(A, dA), Const(b), tape) - - z = transpose(A) \ dy - - y = A \ b - @test dA ≈ (-z * transpose(y)) - - # Ensure multi dim doesn't crash - function test2!(A) - A .= A \ [1.0 0;0.0 1.0] - return nothing - end - - A = rand(2,2) - dA = [1.0 0.0; 0.0 0.0] - - Enzyme.autodiff( - Enzyme.Reverse, - test2!, - Enzyme.Duplicated(A,dA), - ) -end - -function tr_solv(A, B, uplo, trans, diag, idx) - B = copy(B) - LAPACK.trtrs!(uplo, trans, diag, A, B) - return @inbounds B[idx] -end - - -using FiniteDifferences -@testset "Reverse triangular solve" begin - A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542] - B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784] - for idx in 1:15 - for uplo in ('L', 'U') - for diag in ('N', 'U') - for trans in ('N', 'T') - dA = zero(A) - dB = zero(B) - Enzyme.autodiff(Reverse, tr_solv, Duplicated(A, dA), Duplicated(B, dB), Const(uplo),Const(trans), Const(diag), Const(idx)) - fA = FiniteDifferences.grad(central_fdm(5, 1), A->tr_solv(A, B, uplo, trans, diag, idx), A)[1] - fB = FiniteDifferences.grad(central_fdm(5, 1), B->tr_solv(A, B, uplo, trans, diag, idx), B)[1] - - if max(abs.(dA)...) >= 1e-10 || max(abs.(fA)...) >= 1e-10 - @test dA ≈ fA - end - if max(abs.(dB)...) >= 1e-10 || max(abs.(fB)...) 
>= 1e-10 - @test dB ≈ fB - end - end - end - end - end -end +Enzyme.Compiler.DumpPostOpt[] = true function chol_lower0(x) c = copy(x) C, info = LinearAlgebra.LAPACK.potrf!('L', c) - return c[2,1] -end - -function chol_upper0(x) - c = copy(x) - C, info = LinearAlgebra.LAPACK.potrf!('U', c) - return c[1,2] -end - -@testset "Cholesky PotRF" begin - x = reshape([1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0], 4, 4) - dL = zero(x) - dL[2, 1] = 1.0 - - @test Enzyme.gradient(Reverse, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - - @test Enzyme.gradient(Forward, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - - @test FiniteDifferences.grad(central_fdm(5, 1), chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - - @test Enzyme.gradient(Forward, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test Enzyme.gradient(Reverse, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test FiniteDifferences.grad(central_fdm(5, 1), chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] -end - - -function tchol_lower(x, row, col) - c = copy(x) - C, info = LinearAlgebra.LAPACK.potrf!('L', c) - return c[row, col] -end -function tchol_upper(x, row, col) - c = copy(x) - C, info = LinearAlgebra.LAPACK.potrf!('U', c) - return c[row, col] -end - -@testset "Cholesky PotRF 3x3" begin - - x = [1.0 0.13147601759884564 0.5282944836504488; 
0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] - for i in 1:size(x, 1) - for j in 1:size(x, 2) - reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x)[1] - forward_grad = Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)[1] - finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_lower(x, i, j), x)[1] - @test reverse_grad ≈ finite_diff - @test forward_grad ≈ finite_diff - - reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x)[1] - forward_grad = Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)[1] - finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_upper(x, i, j), x)[1] - @test reverse_grad ≈ finite_diff - @test forward_grad ≈ finite_diff - end - end -end - -function tcholsolv_lower(A, B, i) - c = copy(B) - C, info = LinearAlgebra.LAPACK.potrs!('L', A, c) - return c[i] -end -function tcholsolv_upper(A, B, i) - c = copy(B) - C, info = LinearAlgebra.LAPACK.potrs!('U', A, c) - return c[i] -end - - -@testset "Cholesky PotRS 3x5" begin - - x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] - for i in 1:15 - B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3] - reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B)[1] - # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)[1] - finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1] - @test reverse_grad ≈ finite_diff - # @test forward_grad ≈ finite_diff - - reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B)[1] - # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B))[1] - finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1] - @test reverse_grad ≈ finite_diff - # @test forward_grad ≈ finite_diff - - reverse_grad = 
Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x)[1] - #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)[1] - finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1] - @test reverse_grad ≈ finite_diff - #@test forward_grad ≈ finite_diff - # - reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x)[1] - #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)[1] - finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1] - @test reverse_grad ≈ finite_diff - #@test forward_grad ≈ finite_diff - end -end - -@testset "Cholesky" begin - function symmetric_definite(n :: Int=10) - α = one(Float64) - A = spdiagm(-1 => α * ones(n-1), 0 => 4 * ones(n), 1 => conj(α) * ones(n-1)) - b = A * Float64[1:n;] - return A, b - end - - function divdriver_NC(x, fact, b) - res = fact\b - x .= res - return nothing - end - - function ldivdriver_NC(x, fact, b) - ldiv!(fact,b) - x .= b - return nothing - end - - function divdriver(x, A, b) - fact = cholesky(A) - divdriver_NC(x, fact, b) - end - - function divdriver_herm(x, A, b) - fact = cholesky(Hermitian(A)) - divdriver_NC(x, fact, b) - end - - function divdriver_sym(x, A, b) - fact = cholesky(Symmetric(A)) - divdriver_NC(x, fact, b) - end - - function ldivdriver(x, A, b) - fact = cholesky(A) - ldivdriver_NC(x, fact, b) - end - - function ldivdriver_herm(x, A, b) - fact = cholesky(Hermitian(A)) - ldivdriver_NC(x, fact, b) - end - - function ldivdriver_sym(x, A, b) - fact = cholesky(Symmetric(A)) - ldivdriver_NC(x, fact, b) - end - - # Test forward - function fwdJdxdb(driver, A, b) - adJ = zeros(size(A)) - dA = Duplicated(A, zeros(size(A))) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(dA.val, A) - copyto!(db.val, b) - fill!(dA.dval, 0.0) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - db.dval[i] = 1.0 - 
Enzyme.autodiff( - Forward, - driver, - dx, - dA, - db - ) - adJ[i, :] = dx.dval - end - return adJ - end - - function const_fwdJdxdb(driver, A, b) - adJ = zeros(length(b), length(b)) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(db.val, b) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - db.dval[i] = 1.0 - Enzyme.autodiff( - Forward, - driver, - dx, - Const(A), - db - ) - adJ[i, :] = dx.dval - end - return adJ - end - - function batchedfwdJdxdb(driver, A, b) - n = length(b) - function seed(i) - x = zeros(n) - x[i] = 1.0 - return x - end - adJ = zeros(size(A)) - dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) - db = BatchDuplicated(b, ntuple(i -> seed(i), n)) - dx = BatchDuplicated(zeros(length(b)), ntuple(i -> zeros(length(b)), n)) - Enzyme.autodiff( - Forward, - driver, - dx, - dA, - db - ) - for i in 1:n - adJ[i, :] = dx.dval[i] - end - return adJ - end - - # Test reverse - function revJdxdb(driver, A, b) - adJ = zeros(size(A)) - dA = Duplicated(A, zeros(size(A))) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(dA.val, A) - copyto!(db.val, b) - fill!(dA.dval, 0.0) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - dx.dval[i] = 1.0 - Enzyme.autodiff( - Reverse, - driver, - dx, - dA, - db - ) - adJ[i, :] = db.dval - end - return adJ - end - - function const_revJdxdb(driver, A, b) - adJ = zeros(length(b), length(b)) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(db.val, b) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - dx.dval[i] = 1.0 - Enzyme.autodiff( - Reverse, - driver, - dx, - Const(A), - db - ) - adJ[i, :] = db.dval - end - return adJ - end - - function batchedrevJdxdb(driver, A, b) - n = length(b) - function seed(i) - x = zeros(n) - x[i] = 1.0 - return x - end - adJ = zeros(size(A)) - dA = BatchDuplicated(A, ntuple(i -> 
zeros(size(A)), n)) - db = BatchDuplicated(b, ntuple(i -> zeros(length(b)), n)) - dx = BatchDuplicated(zeros(length(b)), ntuple(i -> seed(i), n)) - Enzyme.autodiff( - Reverse, - driver, - dx, - dA, - db - ) - for i in 1:n - adJ[i, :] .= db.dval[i] - end - return adJ - end - - function Jdxdb(driver, A, b) - x = A\b - dA = zeros(size(A)) - db = zeros(length(b)) - J = zeros(length(b), length(b)) - for i in 1:length(b) - db[i] = 1.0 - dx = A\db - db[i] = 0.0 - J[i, :] = dx - end - return J - end - - function JdxdA(driver, A, b) - db = zeros(length(b)) - J = zeros(length(b), length(b)) - for i in 1:length(b) - db[i] = 1.0 - dx = A\db - db[i] = 0.0 - J[i, :] = dx - end - return J - end - - @testset "Testing $op" for (op, driver, driver_NC) in ( - (:\, divdriver, divdriver_NC), - (:\, divdriver_herm, divdriver_NC), - (:\, divdriver_sym, divdriver_NC), - (:ldiv!, ldivdriver, ldivdriver_NC), - (:ldiv!, ldivdriver_herm, ldivdriver_NC), - (:ldiv!, ldivdriver_sym, ldivdriver_NC) - ) - A, b = symmetric_definite(10) - n = length(b) - A = Matrix(A) - x = zeros(n) - x = driver(x, A, b) - fdm = forward_fdm(2, 1); - - function b_one(b) - _x = zeros(length(b)) - driver(_x,A,b) - return _x - end - - fdJ = op==:\ ? FiniteDifferences.jacobian(fdm, b_one, copy(b))[1] : nothing - fwdJ = fwdJdxdb(driver, A, b) - revJ = revJdxdb(driver, A, b) - batchedrevJ = batchedrevJdxdb(driver, A, b) - batchedfwdJ = batchedfwdJdxdb(driver, A, b) - J = Jdxdb(driver, A, b) - - if op == :\ - @test isapprox(fwdJ, fdJ) - end - - @test isapprox(fwdJ, revJ) - @test isapprox(fwdJ, batchedrevJ) - @test isapprox(fwdJ, batchedfwdJ) - - fwdJ = const_fwdJdxdb(driver_NC, cholesky(A), b) - revJ = const_revJdxdb(driver_NC, cholesky(A), b) - if op == :\ - @test isapprox(fwdJ, fdJ) - end - @test isapprox(fwdJ, revJ) - - function h(A, b) - A = copy(A) - LinearAlgebra.LAPACK.potrf!('U', A) - b2 = copy(b) - LinearAlgebra.LAPACK.potrs!('U', A, b2) - @inbounds b2[1] - end - - A = [1.3 0.5; 0.5 1.5] - b = [1., 2.] 
- dA = zero(A) - Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) - # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A)[1] - dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A->h(A, b), A)[1] - - @test isapprox(dA, dA_fd) - end + return @inbounds c[2,1] end -function chol_upper(x) - x = reshape(x, 4, 4) - x = parent(cholesky(Hermitian(x)).U) - x = convert(typeof(x), UpperTriangular(x)) - return x[1,2] -end - -@testset "Cholesky upper triangular v1" begin - x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0] - - @test Enzyme.gradient(Forward, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - - @test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -end - -using EnzymeTestUtils -@testset "Linear solve for triangular matrices" begin - @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), - TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) - n = sizeB[1] - M = rand(TE, n, n) - B = rand(TE, sizeB...) - Y = zeros(TE, sizeB...) 
- A = T(M) - @testset "test through constructor" begin - _A = T(A) - f!(Y, A, B, ::T) where {T} = ldiv!(Y, T(A), B) - for TY in (Const, Duplicated, BatchDuplicated), - TM in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated) - are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const); atol = 1.0e-5, rtol = 1.0e-5) - end - end - @testset "test through `Adjoint` wrapper (regression test for #1306)" begin - # Test that we get the same derivative for `M` as for the adjoint of its - # (materialized) transpose. It's the same matrix, but represented differently - function f!(Y, A, B) - ldiv!(Y, A, B) - return nothing - end - A1 = T(M) - A2 = T(conj(permutedims(M))') - dA1 = make_zero(A1) - dA2 = make_zero(A2) - dB1 = make_zero(B) - dB2 = make_zero(B) - dY1 = rand(TE, sizeB...) - dY2 = copy(dY1) - autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) - autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) - @test dA1.data ≈ dA2.data - @test dB1 ≈ dB2 - end - end -end - -@testset "rand and randn rules" begin - # Distributed as x + unit normal + uniform - struct MyDistribution - x::Float64 - end - - Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn() + rand() - Random.rand(d::MyDistribution) = rand(Random.default_rng(), d) - - # Outer rand should be differentiated through, and inner rand and randn should be ignored. 
- @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) -end - - -@testset "Ranges" begin - function f1(x) - x = 25.0x - ts = Array(Base.range_start_stop_length(0.0, x, 30)) - return sum(ts) - end - function f2(x) - x = 25.0x - ts = Array(Base.range_start_stop_length(0.0, 0.25, 30)) - return sum(ts) + x - end - function f3(x) - ts = Array(Base.range_start_stop_length(x, 1.25, 30)) - return sum(ts) - end - @test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (374.99999999999994,) - @test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,) - @test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (15.0,) - - @test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) == - ((var"1" = 374.99999999999994, var"2" = 749.9999999999999),) - @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) == - ((var"1"=25.0, var"2"=50.0),) - @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == - ((var"1"=15.0, var"2"=30.0),) - - @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((375.0,),) - @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),) - @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((15.0,),) - - # Batch active rule isnt setup - # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((375.0,750.0)),) - # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) - # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((15.0,30.0)),) -end +Enzyme.API.printall!(true) -@testset "Ranges 2" begin - function f1(x) - x = 25.0x - ts = Array(0.0:x:3.0) - return sum(ts) - end - function f2(x) - x = 25.0x - ts = Array(0.0:0.25:3.0) - return sum(ts) + x - end - function f3(x) - x = 
25.0x - ts = Array(x:0.25:3.0) - return sum(ts) - end - function f4(x) - x = 25.0x - ts = Array(0.0:0.25:x) - return sum(ts) - end - @test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (25.0,) - @test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,) - @test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (75.0,) - @test Enzyme.autodiff(Forward, f4, Duplicated(0.12, 1.0)) == (0,) - - @test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) == - ((var"1"=25.0, var"2"=50.0),) - @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) == - ((var"1"=25.0, var"2"=50.0),) - @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == - ((var"1"=75.0, var"2"=150.0),) - @test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) == - ((var"1"=0.0, var"2"=0.0),) - - @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((25.0,),) - @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),) - @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((75.0,),) - @test Enzyme.autodiff(Reverse, f4, Active, Active(0.12)) == ((0.0,),) - - # Batch active rule isnt setup - # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(1.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) - # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) - # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((75.0,150.0)),) - # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),) -end - -@testset "SparseArrays spmatvec reverse rule" begin - C = zeros(18) - M = sprand(18, 9, 0.1) - v = randn(9) - α = 2.0 - β = 1.0 - - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, 
Duplicated, BatchDuplicated), - Tα in (Const, Active), Tβ in (Const, Active) - - are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) - - end - - - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) - are_activities_compatible(Tret, Tret, Tv) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) - end -end - -@testset "SparseArrays spmatmat reverse rule" begin - C = zeros(18, 11) - M = sprand(18, 9, 0.1) - v = randn(9, 11) - α = 2.0 - β = 1.0 - - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), - Tα in (Const, Active), Tβ in (Const, Active) - - are_activities_compatible(Tret, Tv, Tα, Tβ) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) - end - - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) - are_activities_compatible(Tret, Tv) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) - end -end +x = reshape([1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0], 4, 4) + dL = zero(x) + dL[2, 1] = 1.0 -end # InternalRules + @test Enzyme.autodiff(Forward, chol_lower0, Duplicated(x, dL))[1] ≈ 0.05270807565639164 From 1c69a706116b443972f1c3e232a22437b3087d64 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 16:03:29 -0700 Subject: [PATCH 375/495] Adapt to nightly [rm combinemuladd] (#2001) --- src/compiler/optimize.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/compiler/optimize.jl b/src/compiler/optimize.jl index a35de5608f..2a1ef49553 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -353,6 +353,7 @@ end end else function combine_mul_add_tm!(pm, tm) +@static if VERSION < v"1.12.0-DEV.1390" function combine_mul_add(mod) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm @@ -365,6 +366,7 @@ else return true end add!(pm, ModulePass("CombineMulAdd", combine_mul_add)) +end end end From 80c98873cdfc73b92cbc352d5976264956ab7205 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 16:16:40 -0700 Subject: [PATCH 376/495] Fix sret undef (#1990) * Fix sret undef * add test * fix * 1.11: the adventure continues, destroy (#1986) * 1.11: the adventure continues, destroy * fix * fixup * fix * cleanup * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --- src/Enzyme.jl | 4 +-- src/compiler/optimize.jl | 63 +++++++++++++++++++++++++++++++++------- src/rules/typerules.jl | 40 +++++++++++++++++++++++++ src/utils.jl | 9 +++++- test/abi.jl | 30 +++++++++++++++++++ 5 files changed, 133 insertions(+), 13 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 0a3c030f11..c3769e35ac 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1559,7 +1559,7 @@ end Base.@_inline_meta ntuple(Val(N)) do idx Base.@_inline_meta - return (i == idx) ? 1.0 : 0.0 + return (i == idx) ? T(1) : T(0) end end end @@ -1571,7 +1571,7 @@ end Base.@_inline_meta ntuple(Val(N)) do idx Base.@_inline_meta - return (i + start - 1 == idx) ? 1.0 : 0.0 + return (i + start - 1 == idx) ? 
T(1) : T(0) end end end diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 2a1ef49553..dc26d140bb 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1684,17 +1684,23 @@ function propagate_returned!(mod::LLVM.Module) illegalUse = true break end - if !isa(ops[i], LLVM.AllocaInst) + if !isa(ops[i], LLVM.AllocaInst) && !isa(ops[i], LLVM.UndefValue) && !isa(ops[i], LLVM.PoisonValue) illegalUse = true break end - eltype = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i])) + eltype = if isa(ops[i], LLVM.AllocaInst) + LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i])) + else + LLVM.eltype(value_type(ops[i])) + end seenfn = false todo = LLVM.Instruction[] - for u2 in LLVM.uses(ops[i]) + if isa(ops[i], LLVM.AllocaInst) + for u2 in LLVM.uses(ops[i]) un2 = LLVM.user(u2) push!(todo, un2) end + end while length(todo) > 0 un2 = pop!(todo) if isa(un2, LLVM.BitCastInst) @@ -1705,6 +1711,14 @@ function propagate_returned!(mod::LLVM.Module) end continue end + if isa(un2, LLVM.GetElementPtrInst) + push!(torem, un2) + for u3 in LLVM.uses(un2) + un3 = LLVM.user(u3) + push!(todo, un3) + end + continue + end if !isa(un2, LLVM.CallInst) illegalUse = true break @@ -1776,14 +1790,9 @@ function propagate_returned!(mod::LLVM.Module) illegalUse = true break end - if isa(ops[i], LLVM.UndefValue) + if isa(ops[i], LLVM.UndefValue) || isa(ops[i], LLVM.PoisonValue) continue end - @static if LLVM.version() >= v"12" - if isa(ops[i], LLVM.PoisonValue) - continue - end - end if ops[i] == arg continue end @@ -1911,6 +1920,7 @@ function propagate_returned!(mod::LLVM.Module) un = LLVM.user(u) push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) end + delete_writes_into_removed_args(fn, toremove) nfn = LLVM.Function( API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove), ) @@ -1953,6 +1963,39 @@ function propagate_returned!(mod::LLVM.Module) end end end + +function delete_writes_into_removed_args(fn::LLVM.Function, toremove) + args = 
collect(parameters(fn)) + for tr in toremove + tr = tr + 1 + todorep = Tuple{LLVM.Instruction, LLVM.Value}[] + for opv in LLVM.uses(args[tr]) + u = LLVM.user(opv) + push!(todorep, (u, args[tr])) + end + toerase = LLVM.Instruction[] + while length(todorep) != 0 + cur, cval = pop!(todorep) + if isa(cur, LLVM.StoreInst) + if operands(cur)[2] == cval + LLVM.API.LLVMInstructionEraseFromParent(nphi) + continue + end + end + if isa(cur, LLVM.GetElementPtrInst) || + isa(cur, LLVM.BitCastInst) || + isa(cur, LLVM.AddrSpaceCastInst) + for opv in LLVM.uses(cur) + u = LLVM.user(opv) + push!(todorep, (u, cur)) + end + continue + end + throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))")) + end + end +end + function detect_writeonly!(mod::LLVM.Module) for f in functions(mod) if isempty(LLVM.blocks(f)) @@ -2376,7 +2419,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs - ) + ) && any_jltypes(sret_ty(fn, idx)) for u in LLVM.uses(fn) u = LLVM.user(u) if isa(u, LLVM.ConstantExpr) diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 2cba33c14e..de11d3c1cd 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -92,3 +92,43 @@ function inoutcopyslice_rule( end return UInt8(false) end + +function inoutgcloaded_rule( + direction::Cint, + ret::API.CTypeTreeRef, + args::Ptr{API.CTypeTreeRef}, + known_values::Ptr{API.IntList}, + numArgs::Csize_t, + val::LLVM.API.LLVMValueRef, +)::UInt8 + if numArgs != 1 + return UInt8(false) + end + inst = LLVM.Instruction(val) + + legal, typ = abs_typeof(inst) + + if legal + if (direction & API.DOWN) != 0 + ctx = LLVM.context(inst) + dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) + if GPUCompiler.deserves_retbox(typ) + typ = Ptr{typ} + end + rest = typetree(typ, ctx, dl) + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, 
rest) + @assert legal + end + return UInt8(false) + end + + if (direction & API.UP) != 0 + changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret) + @assert legal + end + if (direction & API.DOWN) != 0 + changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2)) + @assert legal + end + return UInt8(false) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 0b441b1b25..55dc69769e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -327,7 +327,6 @@ end export my_methodinstance - @static if VERSION < v"1.11-" @inline function typed_fieldtype(@nospecialize(T::Type), i::Int) @@ -352,3 +351,11 @@ end end export typed_fieldtype + +# returns the inner type of an sret/enzyme_sret/enzyme_sret_v +function sret_ty(fn::LLVM.Function, idx::Int) + return eltype(LLVM.value_type(LLVM.parameters(fn)[idx])) +end + +export sret_ty + diff --git a/test/abi.jl b/test/abi.jl index acc8f26090..f27affd3a4 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -544,6 +544,36 @@ end @inferred hvp_and_gradient!(zeros(2), zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7]) end +function ulogistic(x) + return x > 36 ? 
one(x) : 1 / (one(x) + 1/x) +end + +@noinline function u_transform_tuple(x) + yfirst = ulogistic(@inbounds x[1]) + yfirst, 2 +end + + +@noinline function mytransform(ts, x) + yfirst = ulogistic(@inbounds x[1]) + yrest, _ = u_transform_tuple(x) + (yfirst, yrest) +end + +function undefsret(trf, x) + p = mytransform(trf, x) + return 1/(p[2]) +end + +@testset "Undef sret" begin + trf = 0.1 + + x = randn(3) + dx = zero(x) + undefsret(trf, x) + autodiff(Reverse, undefsret, Active, Const(trf), Duplicated(x, dx)) +end + struct ByRefStruct x::Vector{Float64} v::Vector{Float64} From 274e5e5df1f545bd03f8d067c85b3a50c12ded77 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 16:23:25 -0700 Subject: [PATCH 377/495] Update benchmark_pr.yml --- .github/workflows/benchmark_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml index 1af037fd6a..5daa392daf 100644 --- a/.github/workflows/benchmark_pr.yml +++ b/.github/workflows/benchmark_pr.yml @@ -42,7 +42,7 @@ jobs: mkdir -p plots benchpkgplot ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --npart=10 --format=png --input-dir=results/ --output-dir=plots/ - name: Upload plot as artifact - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: plots path: plots From 6c23ee7c19ecf69b11dd13a9bcc59c15ff54b201 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 22 Oct 2024 16:33:02 -0700 Subject: [PATCH 378/495] Fix justActive condition (#1936) Closes #1935 --- src/compiler.jl | 2 +- test/runtests.jl | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index bb7ec835f8..b3806e2fe8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -469,7 +469,7 @@ end subT = typed_fieldtype(T, f) - if justActive && !allocatedinline(subT) + if justActive && 
ismutabletype(subT) return Val(AnyState) end diff --git a/test/runtests.jl b/test/runtests.jl index b3a64a2a21..18d2e2da79 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -135,6 +135,17 @@ end @test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing) == Enzyme.Compiler.DupState @test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState @test Enzyme.Compiler.active_reg_inner(Tuple{A,A} where A, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState + + # issue #1935 + struct Incomplete + x::Float64 + y + Incomplete(x) = new(x) + # incomplete constructor & non-bitstype field => !Base.allocatedinline(Incomplete) + end + @test Enzyme.Compiler.active_reg_inner(Tuple{Incomplete}, (), nothing, #=justActive=#Val(false)) == Enzyme.Compiler.MixedState + @test Enzyme.Compiler.active_reg_inner(Tuple{Incomplete}, (), nothing, #=justActive=#Val(true)) == Enzyme.Compiler.ActiveState + world = codegen_world_age(typeof(f0), Tuple{Float64}) thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) From 0365ffff299c96a94a07a7aa9b9876588e20e8b3 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Tue, 22 Oct 2024 19:36:13 -0400 Subject: [PATCH 379/495] Restore internal rule tests --- test/internal_rules.jl | 750 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 743 insertions(+), 7 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 67ec233982..fb5926a1d3 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -1,3 +1,4 @@ +module InternalRules using Enzyme using Enzyme.EnzymeRules @@ -6,18 +7,753 @@ using SparseArrays using Test import Random -Enzyme.Compiler.DumpPostOpt[] = true +struct TPair + a::Float64 + b::Float64 +end + +function sorterrfn(t, x) + function lt(a, b) + return a.a < b.a + end + return first(sortperm(t, lt=lt)) * x +end + +@testset "Sort rules" begin + function f1(x) + a = [1.0, 3.0, x] + sort!(a) + return a[2] + end + + @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 + @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) + @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 + + function f2(x) + a = [1.0, -3.0, -x, -2x, x] + sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y)) + return sum(a .* [1, 2, 3, 4, 5]) + end + + @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 + + function f3(x) + a = [2.0, 2.5, x, 1.0] + return partialsort(a, 2) + end + + @test autodiff(Forward, f3, Duplicated(1.5, 1.0))[1] == 1.0 + @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Reverse, f3, Active(1.5))[1][1] == 1.0 + @test autodiff(Reverse, f3, Active(2.5))[1][1] == 0.0 + + 
function f4(x) + a = [2.0, 2.5, x, x / 2] + y = partialsort(a, 1:2) + return sum(y) + end + + @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 + @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) + @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 + @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 + @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 + + dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) + res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) + + @test res[1][2] ≈ 3 + @test dd.dval[1].a ≈ 0 + @test dd.dval[1].b ≈ 0 + @test dd.dval[2].a ≈ 0 + @test dd.dval[2].b ≈ 0 + @test dd.dval[3].a ≈ 0 + @test dd.dval[3].b ≈ 0 +end + +@testset "Linear Solve" begin + A = Float64[2 3; 5 7] + dA = zero(A) + b = Float64[11, 13] + db = zero(b) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape) + + z = transpose(A) \ dy + + y = A \ b + @test dA ≈ (-z * transpose(y)) + @test db ≈ z + + db = zero(b) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Const(A), Duplicated(b, db), tape) + + z = transpose(A) \ dy + + y = A \ b + @test db ≈ z + + dA = zero(A) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), 
Duplicated(A, dA), Const(b), tape) + + z = transpose(A) \ dy + + y = A \ b + @test dA ≈ (-z * transpose(y)) + + # Ensure multi dim doesn't crash + function test2!(A) + A .= A \ [1.0 0;0.0 1.0] + return nothing + end + + A = rand(2,2) + dA = [1.0 0.0; 0.0 0.0] + + Enzyme.autodiff( + Enzyme.Reverse, + test2!, + Enzyme.Duplicated(A,dA), + ) +end + +function tr_solv(A, B, uplo, trans, diag, idx) + B = copy(B) + LAPACK.trtrs!(uplo, trans, diag, A, B) + return @inbounds B[idx] +end + + +using FiniteDifferences +@testset "Reverse triangular solve" begin + A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542] + B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784] + for idx in 1:15 + for uplo in ('L', 'U') + for diag in ('N', 'U') + for trans in ('N', 'T') + dA = zero(A) + dB = zero(B) + Enzyme.autodiff(Reverse, tr_solv, Duplicated(A, dA), Duplicated(B, dB), Const(uplo),Const(trans), Const(diag), Const(idx)) + fA = FiniteDifferences.grad(central_fdm(5, 1), A->tr_solv(A, B, uplo, trans, diag, idx), A)[1] + fB = FiniteDifferences.grad(central_fdm(5, 1), B->tr_solv(A, B, uplo, trans, diag, idx), B)[1] + + if max(abs.(dA)...) >= 1e-10 || max(abs.(fA)...) >= 1e-10 + @test dA ≈ fA + end + if max(abs.(dB)...) >= 1e-10 || max(abs.(fB)...) 
>= 1e-10 + @test dB ≈ fB + end + end + end + end + end +end function chol_lower0(x) c = copy(x) C, info = LinearAlgebra.LAPACK.potrf!('L', c) - return @inbounds c[2,1] + return c[2,1] +end + +function chol_upper0(x) + c = copy(x) + C, info = LinearAlgebra.LAPACK.potrf!('U', c) + return c[1,2] +end + +@testset "Cholesky PotRF" begin + x = reshape([1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0], 4, 4) + dL = zero(x) + dL[2, 1] = 1.0 + + @test Enzyme.gradient(Reverse, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + + @test Enzyme.gradient(Forward, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + + @test FiniteDifferences.grad(central_fdm(5, 1), chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + + @test Enzyme.gradient(Forward, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Reverse, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test FiniteDifferences.grad(central_fdm(5, 1), chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] +end + + +function tchol_lower(x, row, col) + c = copy(x) + C, info = LinearAlgebra.LAPACK.potrf!('L', c) + return c[row, col] +end +function tchol_upper(x, row, col) + c = copy(x) + C, info = LinearAlgebra.LAPACK.potrf!('U', c) + return c[row, col] +end + +@testset "Cholesky PotRF 3x3" begin + + x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 
0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] + for i in 1:size(x, 1) + for j in 1:size(x, 2) + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x)[1] + forward_grad = Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)[1] + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_lower(x, i, j), x)[1] + @test reverse_grad ≈ finite_diff + @test forward_grad ≈ finite_diff + + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x)[1] + forward_grad = Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)[1] + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_upper(x, i, j), x)[1] + @test reverse_grad ≈ finite_diff + @test forward_grad ≈ finite_diff + end + end +end + +function tcholsolv_lower(A, B, i) + c = copy(B) + C, info = LinearAlgebra.LAPACK.potrs!('L', A, c) + return c[i] +end +function tcholsolv_upper(A, B, i) + c = copy(B) + C, info = LinearAlgebra.LAPACK.potrs!('U', A, c) + return c[i] +end + + +@testset "Cholesky PotRS 3x5" begin + + x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] + for i in 1:15 + B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3] + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B)[1] + # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)[1] + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1] + @test reverse_grad ≈ finite_diff + # @test forward_grad ≈ finite_diff + + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B)[1] + # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B))[1] + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1] + @test reverse_grad ≈ finite_diff + # @test forward_grad ≈ finite_diff + + reverse_grad = Enzyme.gradient(Reverse, Const(x -> 
tcholsolv_lower(x, B, i)), x)[1] + #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)[1] + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1] + @test reverse_grad ≈ finite_diff + #@test forward_grad ≈ finite_diff + # + reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x)[1] + #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)[1] + finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1] + @test reverse_grad ≈ finite_diff + #@test forward_grad ≈ finite_diff + end +end + +@testset "Cholesky" begin + function symmetric_definite(n :: Int=10) + α = one(Float64) + A = spdiagm(-1 => α * ones(n-1), 0 => 4 * ones(n), 1 => conj(α) * ones(n-1)) + b = A * Float64[1:n;] + return A, b + end + + function divdriver_NC(x, fact, b) + res = fact\b + x .= res + return nothing + end + + function ldivdriver_NC(x, fact, b) + ldiv!(fact,b) + x .= b + return nothing + end + + function divdriver(x, A, b) + fact = cholesky(A) + divdriver_NC(x, fact, b) + end + + function divdriver_herm(x, A, b) + fact = cholesky(Hermitian(A)) + divdriver_NC(x, fact, b) + end + + function divdriver_sym(x, A, b) + fact = cholesky(Symmetric(A)) + divdriver_NC(x, fact, b) + end + + function ldivdriver(x, A, b) + fact = cholesky(A) + ldivdriver_NC(x, fact, b) + end + + function ldivdriver_herm(x, A, b) + fact = cholesky(Hermitian(A)) + ldivdriver_NC(x, fact, b) + end + + function ldivdriver_sym(x, A, b) + fact = cholesky(Symmetric(A)) + ldivdriver_NC(x, fact, b) + end + + # Test forward + function fwdJdxdb(driver, A, b) + adJ = zeros(size(A)) + dA = Duplicated(A, zeros(size(A))) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(dA.val, A) + copyto!(db.val, b) + fill!(dA.dval, 0.0) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + db.dval[i] = 1.0 + Enzyme.autodiff( + Forward, + driver, + dx, 
+ dA, + db + ) + adJ[i, :] = dx.dval + end + return adJ + end + + function const_fwdJdxdb(driver, A, b) + adJ = zeros(length(b), length(b)) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(db.val, b) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + db.dval[i] = 1.0 + Enzyme.autodiff( + Forward, + driver, + dx, + Const(A), + db + ) + adJ[i, :] = dx.dval + end + return adJ + end + + function batchedfwdJdxdb(driver, A, b) + n = length(b) + function seed(i) + x = zeros(n) + x[i] = 1.0 + return x + end + adJ = zeros(size(A)) + dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) + db = BatchDuplicated(b, ntuple(i -> seed(i), n)) + dx = BatchDuplicated(zeros(length(b)), ntuple(i -> zeros(length(b)), n)) + Enzyme.autodiff( + Forward, + driver, + dx, + dA, + db + ) + for i in 1:n + adJ[i, :] = dx.dval[i] + end + return adJ + end + + # Test reverse + function revJdxdb(driver, A, b) + adJ = zeros(size(A)) + dA = Duplicated(A, zeros(size(A))) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(dA.val, A) + copyto!(db.val, b) + fill!(dA.dval, 0.0) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + dx.dval[i] = 1.0 + Enzyme.autodiff( + Reverse, + driver, + dx, + dA, + db + ) + adJ[i, :] = db.dval + end + return adJ + end + + function const_revJdxdb(driver, A, b) + adJ = zeros(length(b), length(b)) + db = Duplicated(b, zeros(length(b))) + dx = Duplicated(zeros(length(b)), zeros(length(b))) + for i in 1:length(b) + copyto!(db.val, b) + fill!(db.dval, 0.0) + fill!(dx.dval, 0.0) + dx.dval[i] = 1.0 + Enzyme.autodiff( + Reverse, + driver, + dx, + Const(A), + db + ) + adJ[i, :] = db.dval + end + return adJ + end + + function batchedrevJdxdb(driver, A, b) + n = length(b) + function seed(i) + x = zeros(n) + x[i] = 1.0 + return x + end + adJ = zeros(size(A)) + dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) + db = BatchDuplicated(b, 
ntuple(i -> zeros(length(b)), n)) + dx = BatchDuplicated(zeros(length(b)), ntuple(i -> seed(i), n)) + Enzyme.autodiff( + Reverse, + driver, + dx, + dA, + db + ) + for i in 1:n + adJ[i, :] .= db.dval[i] + end + return adJ + end + + function Jdxdb(driver, A, b) + x = A\b + dA = zeros(size(A)) + db = zeros(length(b)) + J = zeros(length(b), length(b)) + for i in 1:length(b) + db[i] = 1.0 + dx = A\db + db[i] = 0.0 + J[i, :] = dx + end + return J + end + + function JdxdA(driver, A, b) + db = zeros(length(b)) + J = zeros(length(b), length(b)) + for i in 1:length(b) + db[i] = 1.0 + dx = A\db + db[i] = 0.0 + J[i, :] = dx + end + return J + end + + @testset "Testing $op" for (op, driver, driver_NC) in ( + (:\, divdriver, divdriver_NC), + (:\, divdriver_herm, divdriver_NC), + (:\, divdriver_sym, divdriver_NC), + (:ldiv!, ldivdriver, ldivdriver_NC), + (:ldiv!, ldivdriver_herm, ldivdriver_NC), + (:ldiv!, ldivdriver_sym, ldivdriver_NC) + ) + A, b = symmetric_definite(10) + n = length(b) + A = Matrix(A) + x = zeros(n) + x = driver(x, A, b) + fdm = forward_fdm(2, 1); + + function b_one(b) + _x = zeros(length(b)) + driver(_x,A,b) + return _x + end + + fdJ = op==:\ ? FiniteDifferences.jacobian(fdm, b_one, copy(b))[1] : nothing + fwdJ = fwdJdxdb(driver, A, b) + revJ = revJdxdb(driver, A, b) + batchedrevJ = batchedrevJdxdb(driver, A, b) + batchedfwdJ = batchedfwdJdxdb(driver, A, b) + J = Jdxdb(driver, A, b) + + if op == :\ + @test isapprox(fwdJ, fdJ) + end + + @test isapprox(fwdJ, revJ) + @test isapprox(fwdJ, batchedrevJ) + @test isapprox(fwdJ, batchedfwdJ) + + fwdJ = const_fwdJdxdb(driver_NC, cholesky(A), b) + revJ = const_revJdxdb(driver_NC, cholesky(A), b) + if op == :\ + @test isapprox(fwdJ, fdJ) + end + @test isapprox(fwdJ, revJ) + + function h(A, b) + A = copy(A) + LinearAlgebra.LAPACK.potrf!('U', A) + b2 = copy(b) + LinearAlgebra.LAPACK.potrs!('U', A, b2) + @inbounds b2[1] + end + + A = [1.3 0.5; 0.5 1.5] + b = [1., 2.] 
+ dA = zero(A) + Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) + # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A)[1] + dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A->h(A, b), A)[1] + + @test isapprox(dA, dA_fd) + end end -Enzyme.API.printall!(true) +function chol_upper(x) + x = reshape(x, 4, 4) + x = parent(cholesky(Hermitian(x)).U) + x = convert(typeof(x), UpperTriangular(x)) + return x[1,2] +end + +@testset "Cholesky upper triangular v1" begin + x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0] + + @test Enzyme.gradient(Forward, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + + @test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +end + +using EnzymeTestUtils +@testset "Linear solve for triangular matrices" begin + @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), + TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) + n = sizeB[1] + M = rand(TE, n, n) + B = rand(TE, sizeB...) + Y = zeros(TE, sizeB...) 
+ A = T(M) + @testset "test through constructor" begin + _A = T(A) + f!(Y, A, B, ::T) where {T} = ldiv!(Y, T(A), B) + for TY in (Const, Duplicated, BatchDuplicated), + TM in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated) + are_activities_compatible(Const, TY, TM, TB) || continue + test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const); atol = 1.0e-5, rtol = 1.0e-5) + end + end + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing + end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) + dY2 = copy(dY1) + autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) + autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 + end + end +end + +@testset "rand and randn rules" begin + # Distributed as x + unit normal + uniform + struct MyDistribution + x::Float64 + end + + Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn() + rand() + Random.rand(d::MyDistribution) = rand(Random.default_rng(), d) + + # Outer rand should be differentiated through, and inner rand and randn should be ignored. 
+ @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) +end + + +@testset "Ranges" begin + function f1(x) + x = 25.0x + ts = Array(Base.range_start_stop_length(0.0, x, 30)) + return sum(ts) + end + function f2(x) + x = 25.0x + ts = Array(Base.range_start_stop_length(0.0, 0.25, 30)) + return sum(ts) + x + end + function f3(x) + ts = Array(Base.range_start_stop_length(x, 1.25, 30)) + return sum(ts) + end + @test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (374.99999999999994,) + @test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,) + @test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (15.0,) + + @test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1" = 374.99999999999994, var"2" = 749.9999999999999),) + @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=25.0, var"2"=50.0),) + @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=15.0, var"2"=30.0),) + + @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((375.0,),) + @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),) + @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((15.0,),) + + # Batch active rule isnt setup + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((375.0,750.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((15.0,30.0)),) +end -x = reshape([1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 
0.34629296878264443, 0.4692436399208845, 1.0], 4, 4) - dL = zero(x) - dL[2, 1] = 1.0 +@testset "Ranges 2" begin + function f1(x) + x = 25.0x + ts = Array(0.0:x:3.0) + return sum(ts) + end + function f2(x) + x = 25.0x + ts = Array(0.0:0.25:3.0) + return sum(ts) + x + end + function f3(x) + x = 25.0x + ts = Array(x:0.25:3.0) + return sum(ts) + end + function f4(x) + x = 25.0x + ts = Array(0.0:0.25:x) + return sum(ts) + end + @test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (25.0,) + @test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,) + @test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (75.0,) + @test Enzyme.autodiff(Forward, f4, Duplicated(0.12, 1.0)) == (0,) + + @test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=25.0, var"2"=50.0),) + @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=25.0, var"2"=50.0),) + @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=75.0, var"2"=150.0),) + @test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) == + ((var"1"=0.0, var"2"=0.0),) + + @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((25.0,),) + @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),) + @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((75.0,),) + @test Enzyme.autodiff(Reverse, f4, Active, Active(0.12)) == ((0.0,),) + + # Batch active rule isnt setup + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(1.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((75.0,150.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); 
nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),) +end + +@testset "SparseArrays spmatvec reverse rule" begin + C = zeros(18) + M = sprand(18, 9, 0.1) + v = randn(9) + α = 2.0 + β = 1.0 + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + Tα in (Const, Active), Tβ in (Const, Active) + + are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + + end + + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tret, Tv) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) + end +end + +@testset "SparseArrays spmatmat reverse rule" begin + C = zeros(18, 11) + M = sprand(18, 9, 0.1) + v = randn(9, 11) + α = 2.0 + β = 1.0 + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + Tα in (Const, Active), Tβ in (Const, Active) + + are_activities_compatible(Tret, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + end + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tv) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) + end +end - @test Enzyme.autodiff(Forward, chol_lower0, Duplicated(x, dL))[1] ≈ 0.05270807565639164 +end # InternalRules From aa44483c6e895a7b54f033b2b264decdb89cd635 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 16:39:58 -0700 Subject: [PATCH 380/495] Fix inlining for nightly (#2005) * Fix inlining for nightly * Update interpreter.jl --- src/compiler/interpreter.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) 
diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 937f61e77d..19337961f6 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -247,6 +247,7 @@ let # overload `inlining_policy` argtypes::Vector{Any}, ) end + @static if isdefined(Core.Compiler, :inlining_policy) @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) if info isa NoInlineCallInfo if info.kind === :primitive @@ -266,6 +267,27 @@ let # overload `inlining_policy` end return @invoke Core.Compiler.inlining_policy($(args_ex.args...)) end + else + @eval function Core.Compiler.src_inlining_policy($(sigs_ex.args...)) + if info isa NoInlineCallInfo + if info.kind === :primitive + @safe_debug "Blocking inlining for primitive func" info.tt + elseif info.kind === :inactive + @safe_debug "Blocking inlining due to inactive rule" info.tt + elseif info.kind === :frule + @safe_debug "Blocking inlining due to frule" info.tt + else + @assert info.kind === :rrule + @safe_debug "Blocking inlining due to rrule" info.tt + end + return nothing + elseif info isa AlwaysInlineCallInfo + @safe_debug "Forcing inlining for primitive func" info.tt + return src + end + return @invoke Core.Compiler.src_inlining_policy($(args_ex.args...)) + end + end end import Core.Compiler: From 68e7e0721cfb27439cf31f26406cd4049206ff29 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 16:40:29 -0700 Subject: [PATCH 381/495] EnzymeTestUtils: mark test as needing runtime activity (#2003) * fix * fix --- lib/EnzymeTestUtils/test/test_forward.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 5f8e5e7c6c..f286f5bfe9 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -90,7 +90,8 @@ end x = TestStruct(randn(T, 5), randn(T)) end atol = rtol = sqrt(eps(real(T))) - test_forward(fun, Tret, (x, Tx); atol, rtol) + 
runtime_activity = TT <: TestStruct && (Tret <: Const) + test_forward(fun, Tret, (x, Tx); atol, rtol, runtime_activity) end end end From 064bedef27ae77fbc075506dacfb072eb894f8ce Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 17:30:00 -0700 Subject: [PATCH 382/495] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c541e97a19..2a6ef3833b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,7 +22,7 @@ jobs: matrix: version: - '1.10' - - ~1.11.0-0 + - '1.11' - 'nightly' os: - ubuntu-20.04 From 6042c012126cfc0c83d7ae36edfe5072412350e6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 17:30:42 -0700 Subject: [PATCH 383/495] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2a6ef3833b..09039caf9a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -137,7 +137,7 @@ jobs: matrix: version: - '1.10' - - ~1.11.0-0 + - '1.11' - 'nightly' os: - ubuntu-latest From 53c31987154c191bcfd55398093ab78435341612 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 17:34:27 -0700 Subject: [PATCH 384/495] Update pipeline.yml --- .buildkite/pipeline.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 98b8facf86..4d277625ce 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -4,6 +4,7 @@ steps: setup: version: - "1.10" + - "1.11" plugins: - JuliaCI/julia#v1: version: "{{matrix.version}}" From 820c0058405725ce657b28f458c87bdc0b38982e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 21:26:22 -0700 Subject: [PATCH 385/495] EnzymeTestUtils: use make_zero instead of zero_tangent (#2006) * fix * fix * EnzymeTestUtils: use make_zero instead of zero_tangent * bump * Add make zero * 
fix * fixup * fix --- lib/EnzymeTestUtils/Project.toml | 2 +- .../src/finite_difference_calls.jl | 10 +- lib/EnzymeTestUtils/src/generate_tangent.jl | 7 - lib/EnzymeTestUtils/src/test_reverse.jl | 4 +- lib/EnzymeTestUtils/src/to_vec.jl | 46 ++++++ lib/EnzymeTestUtils/test/generate_tangent.jl | 26 +--- src/make_zero.jl | 134 ++++++++++++++++++ test/runtests.jl | 6 + 8 files changed, 195 insertions(+), 40 deletions(-) diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 72684a9781..2c481d7ec3 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.2.0" +version = "0.2.1" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/lib/EnzymeTestUtils/src/finite_difference_calls.jl b/lib/EnzymeTestUtils/src/finite_difference_calls.jl index 7433b9ccd9..bb1540bfc3 100644 --- a/lib/EnzymeTestUtils/src/finite_difference_calls.jl +++ b/lib/EnzymeTestUtils/src/finite_difference_calls.jl @@ -29,13 +29,13 @@ function _fd_forward(fdm, f, rettype, y, activities) # vectorize inputs and outputs of function f_vec = first ∘ to_vec ∘ Base.splat(f_sig_args) ∘ from_vec_in if rettype <: Union{Duplicated,DuplicatedNoNeed} - all(ignores) && return zero_tangent(y) + all(ignores) && return Enzyme.make_zero(y) sig_arg_dval_vec, _ = to_vec(ẋs[.!ignores]) ret_deval_vec = FiniteDifferences.jvp(fdm, f_vec, (sig_arg_val_vec, sig_arg_dval_vec)) return from_vec_out(ret_deval_vec) elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed} - all(ignores) && return (var"1"=zero_tangent(y),) + all(ignores) && return (var"1"=Enzyme.make_zero(y),) ret_dvals = map(ẋs[.!ignores]...) do sig_args_dvals... 
sig_args_dvals_vec, _ = to_vec(sig_args_dvals) ret_dval_vec = FiniteDifferences.jvp(fdm, f_vec, @@ -67,13 +67,13 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return) xs = map(x -> x.val, activities) ignores = map(a -> a isa Const, activities) f_sig_args = _wrap_reverse_function(active_return, f, xs, ignores) - all(ignores) && return map(zero_tangent, xs) + all(ignores) && return map(Enzyme.make_zero, xs) ignores = collect(ignores) is_batch = _any_batch_duplicated(map(typeof, activities)...) batch_size = is_batch ? _batch_size(map(typeof, activities)...) : 1 x̄s = map(collect(activities)) do a if a isa Union{Const,Active} - dval = ntuple(_ -> zero_tangent(a.val), batch_size) + dval = ntuple(_ -> Enzyme.make_zero(a.val), batch_size) return is_batch ? dval : dval[1] else return a.dval @@ -178,7 +178,7 @@ function _wrap_reverse_function(active_return, f, xs, ignores) # zero, if the input and output alias. if active_return for k in keys(zeros) - zeros[k] = zero_tangent(k) + zeros[k] = Enzyme.make_zero(k) end end diff --git a/lib/EnzymeTestUtils/src/generate_tangent.jl b/lib/EnzymeTestUtils/src/generate_tangent.jl index 91822a509f..e34591b549 100644 --- a/lib/EnzymeTestUtils/src/generate_tangent.jl +++ b/lib/EnzymeTestUtils/src/generate_tangent.jl @@ -25,13 +25,6 @@ function rand_tangent(rng, x) return from_vec(v_new) end -# differs from Enzyme.make_zero primarily in that reshaped Arrays in the argument will share -# the same memory in the output. 
-function zero_tangent(x) - v, from_vec = to_vec(x) - return from_vec(zero(v)) -end - auto_activity(arg) = auto_activity(Random.default_rng(), arg) function auto_activity(rng, arg::Tuple) if length(arg) == 2 && arg[2] isa Type && arg[2] <: Annotation diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index 543f5de699..2425ea9318 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -92,12 +92,12 @@ function test_reverse( y = fcopy(args_copy...; deepcopy(fkwargs)...) # generate tangent for output if !_any_batch_duplicated(ret_activity, map(typeof, activities)...) - ȳ = ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y) + ȳ = ret_activity <: Const ? Enzyme.make_zero(y) : rand_tangent(rng, y) else batch_size = _batch_size(ret_activity, map(typeof, activities)...) ks = ntuple(Symbol ∘ string, batch_size) ȳ = ntuple(batch_size) do _ - return ret_activity <: Const ? zero_tangent(y) : rand_tangent(y) + return ret_activity <: Const ? 
Enzyme.make_zero(y) : rand_tangent(y) end end # call finitedifferences, avoid mutating original arguments diff --git a/lib/EnzymeTestUtils/src/to_vec.jl b/lib/EnzymeTestUtils/src/to_vec.jl index 412c6efb1b..6585e53acb 100644 --- a/lib/EnzymeTestUtils/src/to_vec.jl +++ b/lib/EnzymeTestUtils/src/to_vec.jl @@ -89,6 +89,52 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:Array} end return x_vec, Array_from_vec end + +@static if VERSION < v"1.11-" +else +# basic containers: loop over defined elements, recursively converting them to vectors +function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:GenericMemory} + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const_nongen(RT, nothing) + if has_seen || is_const + x_vec = Float32[] + else + x_vecs = Vector{<:AbstractFloat}[] + from_vecs = [] + subvec_inds = UnitRange{Int}[] + l = 0 + for i in eachindex(x) + isassigned(x, i) || continue + xi_vec, xi_from_vec = to_vec(x[i], seen_vecs) + push!(x_vecs, xi_vec) + push!(from_vecs, xi_from_vec) + push!(subvec_inds, (l + 1):(l + length(xi_vec))) + l += length(xi_vec) + end + x_vec = reduce(vcat, x_vecs; init=Float32[]) + seen_vecs[x] = x_vec + end + function Memory_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return reshape(seen_xs[x], size(x)) + is_const && return x + x_new = typeof(x)(undef, size(x)) + k = 1 + for i in eachindex(x) + isassigned(x, i) || continue + xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs) + x_new[i] = xi + k += 1 + end + seen_xs[x] = x_new + return x_new + end + return x_vec, Memory_from_vec +end +end + function to_vec(x::Tuple, seen_vecs::AliasDict) x_vec, from_vec = to_vec(collect(x), seen_vecs) function Tuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) diff --git a/lib/EnzymeTestUtils/test/generate_tangent.jl 
b/lib/EnzymeTestUtils/test/generate_tangent.jl index 1e9f9727f4..738f0afa3d 100644 --- a/lib/EnzymeTestUtils/test/generate_tangent.jl +++ b/lib/EnzymeTestUtils/test/generate_tangent.jl @@ -1,6 +1,6 @@ using Test using EnzymeTestUtils -using EnzymeTestUtils: rand_tangent, zero_tangent +using EnzymeTestUtils: rand_tangent using Enzyme using Quaternions @@ -42,30 +42,6 @@ using Quaternions @test y.a != x.a end - @testset "zero_tangent" begin - @test zero_tangent(1) == 1 - @test zero_tangent(true) == true - @test zero_tangent(false) == false - @test zero_tangent(:foo) === :foo - @test zero_tangent("bar") === "bar" - @testset for T in ( - Float32, Float64, ComplexF32, ComplexF64, QuaternionF32, QuaternionF64 - ) - x = randn(T) - @test zero_tangent(x) === zero(T) - y = randn(T, 5) - @test zero_tangent(y) == zero(y) - @test zero_tangent(y) isa typeof(y) - end - x = TestStruct(TestStruct(:foo, TestStruct(1, 3.0f0 + 1im)), [4.0, 5.0]) - y = zero_tangent(x) - @test y.x.x == :foo - @test y.x.a.x == 1 - @test y.x.a.a === zero(ComplexF32) - @test y.a isa Vector{Float64} - @test y.a == zero(x.a) - end - @testset "auto_activity" begin @test EnzymeTestUtils.auto_activity((1.0, Const)) === Const(1.0) @test EnzymeTestUtils.auto_activity((1.0, Active)) === Active(1.0) diff --git a/src/make_zero.jl b/src/make_zero.jl index 4f627581ea..4130f6ce4d 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -16,6 +16,22 @@ end return Base.zero(x) end + +@static if VERSION < v"1.11-" +else +@inline function EnzymeCore.make_zero( + x::GenericMemory{kind, FT}, +)::GenericMemory{kind, FT} where {FT<:AbstractFloat,kind} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero( + x::GenericMemory{kind, Complex{FT}}, +)::GenericMemory{kind, Complex{FT}} where {FT<:AbstractFloat,kind} + return Base.zero(x) +end +end + + @inline function EnzymeCore.make_zero( ::Type{Array{FT,N}}, seen::IdDict, @@ -43,6 +59,36 @@ end return newa end +@static if VERSION < v"1.11-" +else +@inline function 
EnzymeCore.make_zero( + ::Type{GenericMemory{kind, FT}}, + seen::IdDict, + prev::GenericMemory{kind, FT}, + ::Val{copy_if_inactive} = Val(false), +)::GenericMemory{kind, FT} where {copy_if_inactive,FT<:AbstractFloat,kind} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end +@inline function EnzymeCore.make_zero( + ::Type{GenericMemory{kind, Complex{FT}}}, + seen::IdDict, + prev::GenericMemory{kind, Complex{FT}}, + ::Val{copy_if_inactive} = Val(false), +)::GenericMemory{kind, Complex{FT}} where {copy_if_inactive,FT<:AbstractFloat,kind} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end +end + @inline function EnzymeCore.make_zero( ::Type{RT}, seen::IdDict, @@ -86,6 +132,34 @@ end return newa end +@static if VERSION < v"1.11-" +else +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:GenericMemory} + if haskey(seen, prev) + return seen[prev] + end + if guaranteed_const_nongen(RT, nothing) + return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev + end + newa = RT(undef, size(prev)) + seen[prev] = newa + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + innerty = Core.Typeof(pv) + @inbounds newa[I] = + EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) + end + end + return newa +end +end + @inline function EnzymeCore.make_zero( ::Type{RT}, seen::IdDict, @@ -267,6 +341,25 @@ end nothing end +@static if VERSION < v"1.11-" +else +@inline function EnzymeCore.make_zero!( + prev::GenericMemory{kind, T}, + seen::ST, +)::Nothing where {T<:AbstractFloat,kind,ST} + fill!(prev, zero(T)) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Array{GenericMemory, Complex{T}}, + seen::ST, +)::Nothing where {T<:AbstractFloat,kind,ST} + fill!(prev, zero(Complex{T})) + nothing +end +end + @inline function EnzymeCore.make_zero!( prev::Base.RefValue{T}, )::Nothing where {T<:AbstractFloat} @@ -318,6 +411,47 @@ end nothing end +@static if VERSION < v"1.11-" +else +@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::GenericMemory{kind, Complex{T}}, +)::Nothing where {T<:AbstractFloat, kind} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} + if guaranteed_const_nongen(T, nothing) + return + end + if in(seen, prev) + return + end + push!(seen, prev) + + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + @inbounds prev[I] = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + end + end + nothing +end +end + + @inline function EnzymeCore.make_zero!( prev::Base.RefValue{T}, seen::ST, diff --git a/test/runtests.jl 
b/test/runtests.jl index 18d2e2da79..f33f46fb8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -172,7 +172,13 @@ end world = codegen_world_age(typeof(mul2), Tuple{Vector{Float64}}) forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) res = forward(Const(mul2), d) + + @static if VERSION < v"1.11-" @test typeof(res[1]) == Tuple{Float64, Float64} + else + @test typeof(res[1]) == NamedTuple{(Symbol("1"),Symbol("2"),Symbol("3"),Symbol("4"),Symbol("5"),Symbol("6")), Tuple{Any, Core.LLVMPtr{UInt8, 0}, Any, Core.LLVMPtr{Any, 0}, Float64, Float64}} + end + pullback(Const(mul2), d, 1.0, res[1]) @test d.dval[1] ≈ 5.0 @test d.dval[2] ≈ 3.0 From 064e65c8c8c5a2babf84929732bd179dec7d4305 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 23:22:33 -0700 Subject: [PATCH 386/495] More 1.11 debugging (#2007) * More 1.11 debugging * fix shadow * fix --- src/make_zero.jl | 2 +- src/rules/llvmrules.jl | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index 4130f6ce4d..f2fd055c61 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -352,7 +352,7 @@ else end @inline function EnzymeCore.make_zero!( - prev::Array{GenericMemory, Complex{T}}, + prev::GenericMemory{kind, Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} fill!(prev, zero(Complex{T})) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 667d6f8abb..35399d0b24 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -656,6 +656,9 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin algn = 0 i8 = LLVM.IntType(8) + shadowsrcs = LLVM.Value[] + shadowdsts = LLVM.Value[] + for i = 1:width evsrc = if width == 1 @@ -691,6 +694,14 @@ function arraycopy_common(fwd, B, orig, 
shadowsrc, gutils, shadowdst; len=nothin LLVM.memset!(B, shadowdst0, LLVM.ConstantInt(i8, 0, false), length, algn) end + push!(shadowsrcs, shadowsrc0) + push!(shadowdsts, shadowdst0) + end + + for i in 1:width + shadowsrc0 = shadowsrcs[i] + shadowdst0 = shadowdsts[i] + API.sub_transfer( gutils, fwd ? API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, From 5df3d5cdf07766745077b0d7140f5e1bd00f0efd Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 22 Oct 2024 23:22:50 -0700 Subject: [PATCH 387/495] Update Project.toml (#2008) * Update Project.toml * newly working tests --- Project.toml | 2 +- lib/EnzymeTestUtils/test/test_reverse.jl | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index b1c9a8ed17..480db14294 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.156" +Enzyme_jll = "0.0.157" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl index 901c259af8..9a04000027 100644 --- a/lib/EnzymeTestUtils/test/test_reverse.jl +++ b/lib/EnzymeTestUtils/test/test_reverse.jl @@ -131,13 +131,7 @@ end Tx in (Const, Duplicated, BatchDuplicated) are_activities_compatible(Tret, Tx) || continue - if Tx <: Const - test_reverse(f, Tret, (x, Tx)) - else - @test_broken !fails() do - return test_reverse(f, Tret, (x, Tx)) - end - end + test_reverse(f, Tret, (x, Tx)) end end From 647b13c77da370c93ad933c3f7edb695c49ca693 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:05:31 +0200 Subject: [PATCH 388/495] Add CI testing for EnzymeCore (#2009) --- .github/workflows/CI.yml | 62 +++++++++++++++++++++++++++++++++ .gitignore | 1 + lib/EnzymeCore/test/misc.jl | 27 ++++++++++++++ 
lib/EnzymeCore/test/runtests.jl | 28 +++------------ 4 files changed, 94 insertions(+), 24 deletions(-) create mode 100644 lib/EnzymeCore/test/misc.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 09039caf9a..643cb2c043 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -188,6 +188,68 @@ jobs: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false # or true if you want CI to fail when Codecov fails + enzymecore: + name: EnzymeCore - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libEnzyme }} libEnzyme - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + continue-on-error: ${{ matrix.version == 'nightly' }} + env: + JULIA_PROJECT: "lib/EnzymeCore" + strategy: + fail-fast: false + matrix: + version: + - '1.10' + - ~1.11.0-0 + - 'nightly' + os: + - ubuntu-latest + arch: + - x64 + libEnzyme: [packaged] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: actions/cache@v2 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: setup EnzymeCore + shell: julia --color=yes {0} + id: setup_testutils + continue-on-error: ${{ matrix.version == 'nightly' }} + run: | + using Pkg + Pkg.develop([PackageSpec(; path) for path in (".",)]) + Pkg.instantiate() + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + - name: Run the tests + if: matrix.version != 'nightly' || steps.setup_testutils.outcome == 'success' + continue-on-error: ${{ matrix.version == 'nightly' }} + id: run_tests + shell: julia --color=yes {0} + run: | + using Pkg + Pkg.test("EnzymeCore"; coverage=true) + - uses: julia-actions/julia-processcoverage@v1 + if: matrix.version != 'nightly' || 
steps.run_tests.outcome == 'success' + with: + directories: lib/EnzymeCore/src + - uses: codecov/codecov-action@v4 + if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success' + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false # or true if you want CI to fail when Codecov fails integration: name: Integration Tests - ${{ matrix.test }} runs-on: ${{ matrix.os }} diff --git a/.gitignore b/.gitignore index e7ee8ed2f5..09ad77f58c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.jl.cov *.jl.mem /Manifest.toml +lib/EnzymeCore/Manifest.toml /Manifest-v*.toml /test/Manifest.toml /docs/Manifest.toml diff --git a/lib/EnzymeCore/test/misc.jl b/lib/EnzymeCore/test/misc.jl new file mode 100644 index 0000000000..3c24ddc7c3 --- /dev/null +++ b/lib/EnzymeCore/test/misc.jl @@ -0,0 +1,27 @@ +using Test +using EnzymeCore +import EnzymeCore.EnzymeRules: forward, has_frule_from_sig + +g(x) = x ^ 2 +function forward(config, ::Const{typeof(g)}, ::Type{<:Const}, x::Const) + return Const(g(x.val)) +end + +@test has_frule_from_sig(Base.signature_type(g, Tuple{Float64})) + +f(;kwargs) = 1.0 + +function forward(config, ::Const{typeof(f)}, ::Type{<:Const}; kwargs...) 
+ return Const(f(; kwargs...)) +end + +@test has_frule_from_sig(Base.signature_type(f, Tuple{})) + +data = [1.0, 2.0, 3.0, 4.0] + +d = @view data[2:end] +y = @view data[3:end] +@test_skip @test_throws AssertionError Duplicated(d, y) + +@test_throws ErrorException Active(data) +@test_skip @test_throws ErrorException Active(d) diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index d85d4dea15..114e7d7157 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -1,28 +1,8 @@ using Test using EnzymeCore -import EnzymeCore.EnzymeRules: forward, has_frule_from_sig - -g(x) = x ^ 2 -function forward(config, ::Const{typeof(g)}, ::Type{<:Const}, x::Const) - return Const(g(x.val)) -end - -@test has_frule_from_sig(Base.signature_type(g, Tuple{Float64})) - -f(;kwargs) = 1.0 - -function forward(config, ::Const{typeof(f)}, ::Type{<:Const}; kwargs...) - return Const(f(; kwargs...)) +@testset verbose = true "EnzymeCore" begin + @testset "Miscellaneous" begin + include("misc.jl") + end end - -@test has_frule_from_sig(Base.signature_type(f, Tuple{})) - -data = [1.0, 2.0, 3.0, 4.0] - -d = @view data[2:end] -y = @view data[3:end] -@test_throws ErrorException Duplicated(d, y) - -@test_throws ErrorException Active(data) -@test_throws ErrorException Active(d) \ No newline at end of file From cdee02827c96388bfa4b076250dca5d5b431e788 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 24 Oct 2024 16:02:43 -0700 Subject: [PATCH 389/495] No strict aliasing or f9 (#2014) --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index f33f46fb8a..777cb4a63d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -695,6 +695,7 @@ end @test autodiff(Reverse, f8, Active, Active(2.0))[1][1] == 2 @test autodiff(Forward, f8, Duplicated(2.0, 1.0))[1] == 2 + Enzyme.API.strictAliasing!(false) function f9(x) y = [] foreach(i -> push!(y, i^2), [1.0, x, x]) @@ -703,6 +704,7 @@ end @test 
autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 8 @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 8 + Enzyme.API.strictAliasing!(true) f10(x) = hypot(x, 2x) @test autodiff(Reverse, f10, Active, Active(2.0))[1][1] == sqrt(5) end From ecd490cc112479c1c1ab137fdde6ca4f34d883e3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Oct 2024 06:25:37 -0400 Subject: [PATCH 390/495] More 1.11 stuff (#2015) * More 1.11 stuff * fixup --- src/compiler.jl | 38 +++++++++++++++++++++++++++++++------ src/compiler/interpreter.jl | 5 +++++ test/abi.jl | 18 ++++++++++++++++++ 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index b3806e2fe8..22f0c21b13 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1615,8 +1615,8 @@ function julia_error( legal2, obj = absint(cur) # Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple - if legal2 && - active_reg_inner(TT, (), world) == ActiveState && + if legal2 + if active_reg_inner(TT, (), world) == ActiveState && isa(cur, LLVM.ConstantExpr) && cur == data2 if width == 1 @@ -1634,6 +1634,14 @@ function julia_error( end return shadowres end + end + +@static if VERSION < v"1.11-" +else + if obj isa Memory && obj == typeof(obj).instance + return make_batched(ncur, prevbb) + end +end end badval = if legal2 @@ -1652,10 +1660,8 @@ function julia_error( if isa(cur, LLVM.UndefValue) return make_batched(ncur, prevbb) end - @static if LLVM.version() >= v"12" - if isa(cur, LLVM.PoisonValue) - return make_batched(ncur, prevbb) - end + if isa(cur, LLVM.PoisonValue) + return make_batched(ncur, prevbb) end if isa(cur, LLVM.ConstantAggregateZero) return make_batched(ncur, prevbb) @@ -1794,6 +1800,18 @@ function julia_error( return shadowres end end + + if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(x->isa(x, LLVM.ConstantInt), operands(cur)[2:end])) + lhs = 
make_replacement(operands(cur)[1], prevbb) + if illegal + return ncur + end + if lhs == operands(ncur)[1] + return make_batched(ncur, prevbb) + elseif width != 1 && isa(lhs, LLVM.InsertValueInst) && operands(lhs)[2] == operands(ncur)[1] + return make_batched(ncur, prevbb) + end + end if isa(cur, LLVM.PHIInst) Bphi = IRBuilder() @@ -6322,6 +6340,14 @@ function GPUCompiler.codegen( func = mi.specTypes.parameters[1] +@static if VERSION < v"1.11-" +else + if func == typeof(Core.memoryref) + attributes = function_attributes(llvmfn) + push!(attributes, EnumAttribute("alwaysinline", 0)) + end +end + meth = mi.def name = meth.name jlmod = meth.module diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 19337961f6..bf80c3be30 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -111,6 +111,11 @@ end function is_alwaysinline_func(@nospecialize(TT)) isa(TT, DataType) || return false + @static if VERSION ≥ v"1.11-" + if TT.parameters[1] == typeof(Core.memoryref) + return true + end + end return false end diff --git a/test/abi.jl b/test/abi.jl index f27affd3a4..1c62741ef1 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -590,5 +590,23 @@ end Enzyme.autodiff(Forward, byrefs, BatchDuplicated([1.0], ([1.0], [1.0])), BatchDuplicated([1.0], ([1.0], [1.0]) ) ) end + +function myunique0() + return Vector{Float64}(undef, 0) +end +@static if VERSION < v"1.11-" +@testset "Forward mode array construct" begin + autodiff(Forward, myunique0, Duplicated) +end +else +function myunique() + m = Memory{Float64}.instance + return Core.memoryref(m) +end +@testset "Forward mode array construct" begin + autodiff(Forward, myunique, Duplicated) + autodiff(Forward, myunique0, Duplicated) +end +end include("usermixed.jl") From 229db30fb1f0da0d84998f4420466679d79d24ec Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Oct 2024 15:25:20 -0500 Subject: [PATCH 391/495] 1.11 fix copy bounds error (#2017) * 1.11 fix copy bounds error * fix * unsafe is 
unsafe --- src/compiler/interpreter.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index bf80c3be30..cbb144af41 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -318,7 +318,12 @@ end else @inline function myunsafe_copyto!(dest::MemoryRef{T}, src::MemoryRef{T}, n) where {T} Base.@_terminates_globally_notaskstate_meta - @boundscheck memoryref(dest, n), memoryref(src, n) + # if dest.length < n + # throw(BoundsError(dest, 1:n)) + # end + # if src.length < n + # throw(BoundsError(src, 1:n)) + # end t1 = Base.@_gc_preserve_begin dest t2 = Base.@_gc_preserve_begin src Base.memmove(pointer(dest), pointer(src), n * Base.aligned_sizeof(T)) From 201c993afd2dad7818b4d0c6b43b0c07cb3f558f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Oct 2024 16:26:21 -0500 Subject: [PATCH 392/495] Generalize interpreter (#2019) --- src/compiler/interpreter.jl | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index cbb144af41..098b476d73 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -40,14 +40,18 @@ struct EnzymeInterpreter <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams - mode::API.CDerivativeMode + forward_rules::Bool + reverse_rules::Bool + deferred_lower::Bool end function EnzymeInterpreter( cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, - mode::API.CDerivativeMode, + forward_rules::Bool, + reverse_rules::Bool, + deferred_lower::Bool = true ) @assert world <= Base.get_world_counter() @@ -70,10 +74,20 @@ function EnzymeInterpreter( # parameters for inference and optimization parms, OptimizationParams(), - mode, + forward_rules, + reverse_rules, + deferred_lower ) end +EnzymeInterpreter( + cache_or_token, + mt::Union{Nothing,Core.MethodTable}, + world::UInt, + 
mode::API.CDerivativeMode, + deferred_lower::Bool = true +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower) + Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params get_inference_world(interp::EnzymeInterpreter) = interp.world @@ -206,12 +220,18 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = AlwaysInlineCallInfo(callinfo, atype) elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :inactive) - elseif interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + else + if interp.forward_rules + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :frule) + end + end + + if interp.reverse_rules + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + end end - elseif EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) @@ -392,7 +412,7 @@ function abstract_call_known( end end - if f === Enzyme.autodiff && length(argtypes) >= 4 + if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4 if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} From 92d1ebdb1a9dd42078dfd5aa3c9b39506bc9c894 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Oct 2024 22:12:09 -0500 Subject: [PATCH 393/495] Generalize interpreter v2 
(#2023) * Generalize interpreter v2 * More fix * fix --- src/compiler/interpreter.jl | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 098b476d73..d1db80b0b9 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -23,7 +23,7 @@ else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end -struct EnzymeInterpreter <: AbstractInterpreter +struct EnzymeInterpreter{T} <: AbstractInterpreter @static if HAS_INTEGRATED_CACHE token::Any else @@ -43,6 +43,7 @@ struct EnzymeInterpreter <: AbstractInterpreter forward_rules::Bool reverse_rules::Bool deferred_lower::Bool + handler::T end function EnzymeInterpreter( @@ -51,7 +52,8 @@ function EnzymeInterpreter( world::UInt, forward_rules::Bool, reverse_rules::Bool, - deferred_lower::Bool = true + deferred_lower::Bool = true, + handler = nothing ) @assert world <= Base.get_world_counter() @@ -76,7 +78,8 @@ function EnzymeInterpreter( OptimizationParams(), forward_rules, reverse_rules, - deferred_lower + deferred_lower, + handler ) end @@ -85,8 +88,9 @@ EnzymeInterpreter( mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode, - deferred_lower::Bool = true -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower) + deferred_lower::Bool = true, + handler = nothing +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, handler) Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params @@ -112,16 +116,8 @@ Core.Compiler.may_compress(::EnzymeInterpreter) = 
true Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false -if isdefined(Base.Experimental, Symbol("@overlay")) - Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) -else - - # On 1.6- CUDA.jl will poison the method table at the end of the world - # using GPUCompiler: WorldOverlayMethodTable - # Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - # WorldOverlayMethodTable(interp.world) -end +Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = + Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) function is_alwaysinline_func(@nospecialize(TT)) isa(TT, DataType) || return false @@ -431,6 +427,9 @@ function abstract_call_known( ) end end + if interp.handler != nothing + return interp.handler(interp, f, arginfo, si, sv, max_methods) + end return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f, From 2bde982c3ad54542e5d393ad2f1e27fd9a68efe0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Oct 2024 22:13:25 -0500 Subject: [PATCH 394/495] fix EnzymeTestUtils (#2022) * fix * fix --- lib/EnzymeTestUtils/Project.toml | 2 +- .../src/finite_difference_calls.jl | 39 +++++++++++++------ lib/EnzymeTestUtils/src/generate_tangent.jl | 7 ++++ lib/EnzymeTestUtils/src/test_reverse.jl | 4 +- lib/EnzymeTestUtils/test/generate_tangent.jl | 26 ++++++++++++- 5 files changed, 62 insertions(+), 16 deletions(-) diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 2c481d7ec3..981c284b8c 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.2.1" +version = "0.2.2" [deps] ConstructionBase = 
"187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/lib/EnzymeTestUtils/src/finite_difference_calls.jl b/lib/EnzymeTestUtils/src/finite_difference_calls.jl index bb1540bfc3..dadc53c7f1 100644 --- a/lib/EnzymeTestUtils/src/finite_difference_calls.jl +++ b/lib/EnzymeTestUtils/src/finite_difference_calls.jl @@ -29,13 +29,13 @@ function _fd_forward(fdm, f, rettype, y, activities) # vectorize inputs and outputs of function f_vec = first ∘ to_vec ∘ Base.splat(f_sig_args) ∘ from_vec_in if rettype <: Union{Duplicated,DuplicatedNoNeed} - all(ignores) && return Enzyme.make_zero(y) + all(ignores) && return zero_tangent(y) sig_arg_dval_vec, _ = to_vec(ẋs[.!ignores]) ret_deval_vec = FiniteDifferences.jvp(fdm, f_vec, (sig_arg_val_vec, sig_arg_dval_vec)) return from_vec_out(ret_deval_vec) elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed} - all(ignores) && return (var"1"=Enzyme.make_zero(y),) + all(ignores) && return (var"1"=zero_tangent(y),) ret_dvals = map(ẋs[.!ignores]...) do sig_args_dvals... sig_args_dvals_vec, _ = to_vec(sig_args_dvals) ret_dval_vec = FiniteDifferences.jvp(fdm, f_vec, @@ -49,6 +49,16 @@ function _fd_forward(fdm, f, rettype, y, activities) end _fd_forward(fdm, f, ::Type{<:Const}, y, activities) = () +function multi_tovec(active_return, vals) + if active_return + v0, v1 = vals[1], Base.tail(vals) + res = vcat(to_vec(v0)[1], to_vec(v1)[1]) + return res + else + to_vec(vals)[1] + end +end + #= _fd_reverse(fdm, f, ȳ, activities, active_return) @@ -67,13 +77,13 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return) xs = map(x -> x.val, activities) ignores = map(a -> a isa Const, activities) f_sig_args = _wrap_reverse_function(active_return, f, xs, ignores) - all(ignores) && return map(Enzyme.make_zero, xs) + all(ignores) && return map(zero_tangent, xs) ignores = collect(ignores) is_batch = _any_batch_duplicated(map(typeof, activities)...) batch_size = is_batch ? _batch_size(map(typeof, activities)...) 
: 1 x̄s = map(collect(activities)) do a if a isa Union{Const,Active} - dval = ntuple(_ -> Enzyme.make_zero(a.val), batch_size) + dval = ntuple(_ -> zero_tangent(a.val), batch_size) return is_batch ? dval : dval[1] else return a.dval @@ -84,15 +94,15 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return) sigarginds = eachindex(x̄s)[.!ignores] sigargs_vec, from_vec_in = to_vec(sigargs) # vectorize inputs and outputs of function - f_vec = first ∘ to_vec ∘ Base.splat(f_sig_args) ∘ from_vec_in + f_vec = Base.Fix1(multi_tovec, active_return) ∘ Base.splat(f_sig_args) ∘ from_vec_in if !is_batch ȳ_extended = (ȳ, s̄igargs...) - ȳ_extended_vec, _ = to_vec(ȳ_extended) + ȳ_extended_vec = multi_tovec(active_return, ȳ_extended) fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec)) fd = from_vec_in(fd_vec) else fd = Tuple(zip(map(ȳ, s̄igargs...) do ȳ_extended... - ȳ_extended_vec, _ = to_vec(ȳ_extended) + ȳ_extended_vec = multi_tovec(active_return, ȳ_extended) fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec)) return from_vec_in(fd_vec) @@ -154,11 +164,13 @@ function _wrap_reverse_function(active_return, f, xs, ignores) retargs = Any[] j = 1 + inputs = IdDict() + for (i, (x, ignore)) in enumerate(zip(xs, ignores)) if ignore - push!(callargs, deepcopy(x)) + push!(callargs, Base.deepcopy_internal(x, inputs)) else - arg = deepcopy(sigargs[j]) + arg = Base.deepcopy_internal(sigargs[j], inputs) push!(callargs, arg) push!(retargs, arg) j += 1 @@ -172,17 +184,20 @@ function _wrap_reverse_function(active_return, f, xs, ignores) # it will already be taken into account. This is implemented using the deepcopy_internal, which # will add all objects inside the return into the dict `zeros`. zeros = IdDict() - origRet = Base.deepcopy_internal(deepcopy(f)(callargs...), zeros) + origRet = Base.deepcopy_internal(f, inputs)(callargs...) 
+ Base.deepcopy_internal(origRet, zeros) # we will now explicitly zero all objects returned, and replace any of the args with this # zero, if the input and output alias. if active_return for k in keys(zeros) - zeros[k] = Enzyme.make_zero(k) + zeros[k] = zero_tangent(k) end + return (origRet, Base.deepcopy_internal(retargs, zeros)...) + else + return (origRet, retargs...) end - return (origRet, Base.deepcopy_internal(retargs, zeros)...) end return fnew end diff --git a/lib/EnzymeTestUtils/src/generate_tangent.jl b/lib/EnzymeTestUtils/src/generate_tangent.jl index e34591b549..91822a509f 100644 --- a/lib/EnzymeTestUtils/src/generate_tangent.jl +++ b/lib/EnzymeTestUtils/src/generate_tangent.jl @@ -25,6 +25,13 @@ function rand_tangent(rng, x) return from_vec(v_new) end +# differs from Enzyme.make_zero primarily in that reshaped Arrays in the argument will share +# the same memory in the output. +function zero_tangent(x) + v, from_vec = to_vec(x) + return from_vec(zero(v)) +end + auto_activity(arg) = auto_activity(Random.default_rng(), arg) function auto_activity(rng, arg::Tuple) if length(arg) == 2 && arg[2] isa Type && arg[2] <: Annotation diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index 2425ea9318..543f5de699 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -92,12 +92,12 @@ function test_reverse( y = fcopy(args_copy...; deepcopy(fkwargs)...) # generate tangent for output if !_any_batch_duplicated(ret_activity, map(typeof, activities)...) - ȳ = ret_activity <: Const ? Enzyme.make_zero(y) : rand_tangent(rng, y) + ȳ = ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y) else batch_size = _batch_size(ret_activity, map(typeof, activities)...) ks = ntuple(Symbol ∘ string, batch_size) ȳ = ntuple(batch_size) do _ - return ret_activity <: Const ? Enzyme.make_zero(y) : rand_tangent(y) + return ret_activity <: Const ? 
zero_tangent(y) : rand_tangent(y) end end # call finitedifferences, avoid mutating original arguments diff --git a/lib/EnzymeTestUtils/test/generate_tangent.jl b/lib/EnzymeTestUtils/test/generate_tangent.jl index 738f0afa3d..1e9f9727f4 100644 --- a/lib/EnzymeTestUtils/test/generate_tangent.jl +++ b/lib/EnzymeTestUtils/test/generate_tangent.jl @@ -1,6 +1,6 @@ using Test using EnzymeTestUtils -using EnzymeTestUtils: rand_tangent +using EnzymeTestUtils: rand_tangent, zero_tangent using Enzyme using Quaternions @@ -42,6 +42,30 @@ using Quaternions @test y.a != x.a end + @testset "zero_tangent" begin + @test zero_tangent(1) == 1 + @test zero_tangent(true) == true + @test zero_tangent(false) == false + @test zero_tangent(:foo) === :foo + @test zero_tangent("bar") === "bar" + @testset for T in ( + Float32, Float64, ComplexF32, ComplexF64, QuaternionF32, QuaternionF64 + ) + x = randn(T) + @test zero_tangent(x) === zero(T) + y = randn(T, 5) + @test zero_tangent(y) == zero(y) + @test zero_tangent(y) isa typeof(y) + end + x = TestStruct(TestStruct(:foo, TestStruct(1, 3.0f0 + 1im)), [4.0, 5.0]) + y = zero_tangent(x) + @test y.x.x == :foo + @test y.x.a.x == 1 + @test y.x.a.a === zero(ComplexF32) + @test y.a isa Vector{Float64} + @test y.a == zero(x.a) + end + @testset "auto_activity" begin @test EnzymeTestUtils.auto_activity((1.0, Const)) === Const(1.0) @test EnzymeTestUtils.auto_activity((1.0, Active)) === Active(1.0) From 12c1abbf3a32db5a24bce2e5f23bd5026723d92c Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 29 Oct 2024 14:16:04 -0400 Subject: [PATCH 395/495] Julia 1.12: WIP (#2020) * Julia 1.12: WIP * linetable * Update utils.jl * Update utils.jl --- src/utils.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 55dc69769e..9492dccdd6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -213,8 +213,12 @@ function codegen_world_age_generator(world::UInt, source, self, ft::Type, tt::Ty # prepare a new code 
info new_ci = copy(ci) empty!(new_ci.code) - empty!(new_ci.codelocs) - resize!(new_ci.linetable, 1) # see note below + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end empty!(new_ci.ssaflags) new_ci.ssavaluetypes = 0 new_ci.min_world = min_world[] @@ -232,7 +236,10 @@ function codegen_world_age_generator(world::UInt, source, self, ft::Type, tt::Ty # return the codegen world age push!(new_ci.code, ReturnNode(world)) push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` - push!(new_ci.codelocs, 1) # see note below + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below + end new_ci.ssavaluetypes += 1 # NOTE: we keep the first entry of the original linetable, and use it for location info From 19ef90bb96f23783c2fe0ab409d7af308a5587f4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 30 Oct 2024 02:16:31 -0400 Subject: [PATCH 396/495] More nightly stuff (#2026) --- src/compiler/interpreter.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index d1db80b0b9..51e20f8dc4 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -57,10 +57,10 @@ function EnzymeInterpreter( ) @assert world <= Base.get_world_counter() - parms = @static if VERSION < v"1.12" - InferenceParams(unoptimize_throw_blocks = false) - else + parms = @static if VERSION >= v"1.12.0-DEV.1017" InferenceParams() + else + InferenceParams(; unoptimize_throw_blocks=false) end return EnzymeInterpreter( From 5332027dc412c5b5215608d6d64dfa3d12b94596 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Wed, 30 Oct 2024 02:38:09 -0400 Subject: [PATCH 397/495] More refined scratch dirs --- deps/build_local.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/deps/build_local.jl b/deps/build_local.jl index 32089e8553..9c47011bac 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -9,10 +9,6 @@ using Pkg, Scratch, Preferences, Libdl BUILD_TYPE = "RelWithDebInfo" BCLoad = true -# 1. Get a scratch directory -scratch_dir = get_scratch!(Enzyme_jll, "build") -isdir(scratch_dir) && rm(scratch_dir; recursive=true) - source_dir = nothing branch = nothing @@ -85,6 +81,11 @@ end LLVM_DIR = joinpath(LLVM.artifact_dir, "lib", "cmake", "llvm") LLVM_VER_MAJOR = Base.libllvm_version.major +# 1. Get a scratch directory +scratch_dir = get_scratch!(Enzyme_jll, "build_$(LLVM_VER_MAJOR)_$(llvm_assertions)") +isdir(scratch_dir) && rm(scratch_dir; recursive=true) + + # Build! @info "Building" source_dir scratch_dir LLVM_DIR BUILD_TYPE run(`cmake -DLLVM_DIR=$(LLVM_DIR) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) -DENZYME_EXTERNAL_SHARED_LIB=ON -B$(scratch_dir) -S$(source_dir)`) From c5cf7dbf71eae999125926ee10747b45a91f213b Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 30 Oct 2024 02:44:09 -0400 Subject: [PATCH 398/495] Fix empty forward gradient (#2025) --- src/Enzyme.jl | 3 ++- src/internal_rules.jl | 10 +++++++--- src/rules/customrules.jl | 28 +++++++++++++++++++++++----- test/sugar.jl | 6 ++++++ 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c3769e35ac..2e8643744b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1563,6 +1563,7 @@ end end end end +@inline onehot(x::Tuple{}) = () @inline function onehot(x::NTuple{N,T}) where {T,N} onehot(NTuple{N,T}) end @@ -2141,7 +2142,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) end end else - :(specialize_output(TupleArray($tmp, size($arg)), $(vals[1]))) + tmp end else tmp diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 
6fe70df8cf..539223ffa5 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -1303,18 +1303,22 @@ end function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfig, func::Const{Colon}, - ::Type{<:Active}, + ::Type{RT}, start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}, -) +) where RT <: Active if EnzymeRules.needs_primal(config) primal = func.val(start.val, step.val, stop.val) else primal = nothing end - return EnzymeRules.AugmentedReturn(primal, nothing, nothing) + return EnzymeRules.AugmentedReturn{ + EnzymeRules.primal_type(config, RT), + Nothing, + Nothing + }(primal, nothing, nothing) end function EnzymeRules.reverse( diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index fbd646866b..96661849b2 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -968,6 +968,7 @@ end TapeT = Nothing + if ( aug_RT <: EnzymeRules.AugmentedReturn || aug_RT <: EnzymeRules.AugmentedReturnFlexShadow @@ -996,7 +997,7 @@ end else TapeT = Any end - + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) llvmf = nothing @@ -1027,8 +1028,16 @@ end rkwfunc = Core.kwfunc(EnzymeRules.reverse) if EnzymeRules.isapplicable(rkwfunc, rev_TT; world) @safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT - llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) - rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) + try + llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) + rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) + catch e + rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} + applicablefn = false + end else rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) @@ -1039,8 +1048,17 @@ end else if 
EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) @safe_debug "Applying custom reverse rule" TT = rev_TT - llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) - rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) + try + llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) + rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) + catch e + rev_TT = + Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} + applicablefn = false + end else rev_TT = Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} diff --git a/test/sugar.jl b/test/sugar.jl index 097472ab22..340a54c569 100644 --- a/test/sugar.jl +++ b/test/sugar.jl @@ -4,6 +4,12 @@ using LinearAlgebra mul_scalar(x, y) = x[1]*y[2] + x[2]*y[1] mul_vector(x, y) = [x[1]*y[2], x[2]*y[1]] +@testset "Forward Empty Gradient" begin + inp = Float64[] + res = gradient(Forward, sin, inp) + @test res[1] === inp +end + @testset "Forward Multi-Arg Gradient" begin res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]) @test res[1] ≈ [3.1, 2.7] From 3728b0c2d80dd6816785e47ef130c5f35716bc47 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 30 Oct 2024 09:07:49 -0500 Subject: [PATCH 399/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 480db14294..fd1866fa56 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.12" +version = "0.13.13" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.157" +Enzyme_jll = 
"0.0.158" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 26ca6fe82e64aaca34d85554b8d60de24360632d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Wed, 30 Oct 2024 22:17:48 +0100 Subject: [PATCH 400/495] Import `needs_primal` from Reactant.jl (#2021) * Implement `needs_primal` * Test `needs_primal` * Export `needs_primal` * Bump versions * Move `needs_primal` tests EnzymeCore * Use tilde version specifier to also support v0.8.4 * Move `WithPrimal`, `NoPrimal` tests to EnzymeCore --------- Co-authored-by: William Moses --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 14 ++++++++++++++ lib/EnzymeCore/test/runtests.jl | 28 ++++++++++++++++++++++++++++ test/runtests.jl | 19 ------------------- 5 files changed, 44 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index fd1866fa56..ed748bdc77 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.4" +EnzymeCore = "0.8.4, 0.8.5" Enzyme_jll = "0.0.158" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 2e45d2c2f6..18c3bbad00 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.4" +version = "0.8.5" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 394cd00a5f..c751aaac38 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -7,6 +7,7 @@ export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, 
InlineABI, NonGenABI export BatchDuplicatedFunc export within_autodiff +export needs_primal function batch_size end @@ -351,6 +352,15 @@ Return a new mode which excludes the primal value. """ @inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() +""" + needs_primal(::Mode) + needs_primal(::Type{Mode}) + +Returns `true` if the mode needs the primal value, otherwise `false`. +""" +@inline needs_primal(::ReverseMode{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal +@inline needs_primal(::Type{<:ReverseMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal + """ struct ReverseModeSplit{ ReturnPrimal, @@ -424,6 +434,8 @@ Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to @inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() @inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline needs_primal(::ReverseModeSplit{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal +@inline needs_primal(::Type{<:ReverseModeSplit{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal """ struct ForwardMode{ @@ -480,6 +492,8 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline 
WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}() @inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() +@inline needs_primal(::ForwardMode{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal +@inline needs_primal(::Type{<:ForwardMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal function autodiff end function autodiff_deferred end diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index 114e7d7157..61f0e7af5c 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -2,6 +2,34 @@ using Test using EnzymeCore @testset verbose = true "EnzymeCore" begin + @testset "WithPrimal" begin + @test WithPrimal(Reverse) === ReverseWithPrimal + @test NoPrimal(Reverse) === Reverse + @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal + @test NoPrimal(ReverseWithPrimal) === Reverse + + @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) + + @test WithPrimal(Forward) === ForwardWithPrimal + @test NoPrimal(Forward) === Forward + @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal + @test NoPrimal(ForwardWithPrimal) === Forward + + @test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal + @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal + end + + @testset "needs_primal" begin + @test needs_primal(Reverse) === false + @test needs_primal(ReverseWithPrimal) === true + @test needs_primal(Forward) === false + @test needs_primal(ForwardWithPrimal) === true + @test needs_primal(ReverseSplitNoPrimal) === false + @test 
needs_primal(ReverseSplitWithPrimal) === true + end + @testset "Miscellaneous" begin include("misc.jl") end diff --git a/test/runtests.jl b/test/runtests.jl index 777cb4a63d..d28461d26a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3615,25 +3615,6 @@ end @test res[2][6] ≈ 6.0 end -@testset "WithPrimal" begin - @test WithPrimal(Reverse) === ReverseWithPrimal - @test NoPrimal(Reverse) === Reverse - @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal - @test NoPrimal(ReverseWithPrimal) === Reverse - - @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) - - @test WithPrimal(Forward) === ForwardWithPrimal - @test NoPrimal(Forward) === Forward - @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal - @test NoPrimal(ForwardWithPrimal) === Forward - - @test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal - @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal - @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal - @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal -end - # TEST EXTENSIONS using SpecialFunctions @testset "SpecialFunctions ext" begin From 2c8a5818a4a51b0846e9d2982b8cfb413579fdc8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 31 Oct 2024 10:15:51 -0500 Subject: [PATCH 401/495] Update Project.toml --- lib/EnzymeTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 981c284b8c..2c481d7ec3 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.2.2" +version = "0.2.1" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" From 6e6d6b8c0ad851937949ac883e72c2fa9b7674c3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 1 Nov 2024 
01:32:35 -0500 Subject: [PATCH 402/495] Fix enzymecore tests (#2038) * ix enzymecore tests * fix * Update runtests.jl --- .github/workflows/CI.yml | 2 +- lib/EnzymeCore/test/runtests.jl | 38 ++++++++++++++++----------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 643cb2c043..0f47a044b6 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -199,7 +199,7 @@ jobs: matrix: version: - '1.10' - - ~1.11.0-0 + - '1.11' - 'nightly' os: - ubuntu-latest diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index 61f0e7af5c..2fb7fd74f2 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -3,31 +3,31 @@ using EnzymeCore @testset verbose = true "EnzymeCore" begin @testset "WithPrimal" begin - @test WithPrimal(Reverse) === ReverseWithPrimal - @test NoPrimal(Reverse) === Reverse - @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal - @test NoPrimal(ReverseWithPrimal) === Reverse + @test EnzymeCore.WithPrimal(Reverse) === ReverseWithPrimal + @test EnzymeCore.NoPrimal(Reverse) === Reverse + @test EnzymeCore.WithPrimal(ReverseWithPrimal) === ReverseWithPrimal + @test EnzymeCore.NoPrimal(ReverseWithPrimal) === Reverse - @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) + @test EnzymeCore.WithPrimal(EnzymeCore.set_runtime_activity(Reverse)) === EnzymeCore.set_runtime_activity(ReverseWithPrimal) - @test WithPrimal(Forward) === ForwardWithPrimal - @test NoPrimal(Forward) === Forward - @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal - @test NoPrimal(ForwardWithPrimal) === Forward + @test EnzymeCore.WithPrimal(Forward) === ForwardWithPrimal + @test EnzymeCore.NoPrimal(Forward) === Forward + @test EnzymeCore.WithPrimal(ForwardWithPrimal) === ForwardWithPrimal + @test EnzymeCore.NoPrimal(ForwardWithPrimal) === Forward - @test WithPrimal(ReverseSplitNoPrimal) === 
ReverseSplitWithPrimal - @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal - @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal - @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal + @test EnzymeCore.WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal + @test EnzymeCore.NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal + @test EnzymeCore.WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal + @test EnzymeCore.NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal end @testset "needs_primal" begin - @test needs_primal(Reverse) === false - @test needs_primal(ReverseWithPrimal) === true - @test needs_primal(Forward) === false - @test needs_primal(ForwardWithPrimal) === true - @test needs_primal(ReverseSplitNoPrimal) === false - @test needs_primal(ReverseSplitWithPrimal) === true + @test EnzymeCore.needs_primal(Reverse) === false + @test EnzymeCore.needs_primal(ReverseWithPrimal) === true + @test EnzymeCore.needs_primal(Forward) === false + @test EnzymeCore.needs_primal(ForwardWithPrimal) === true + @test EnzymeCore.needs_primal(ReverseSplitNoPrimal) === false + @test EnzymeCore.needs_primal(ReverseSplitWithPrimal) === true end @testset "Miscellaneous" begin From ff9df4333a7941c518c7298073eabf8c6939c2b0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 1 Nov 2024 02:24:10 -0500 Subject: [PATCH 403/495] More 1.11 type info (#2028) * More 1.11 type info * fix * fix * pre * fix * no infinite recur * fix * fix * fix * fix * fix * fix * fix * fix * clean --- src/absint.jl | 312 ++++++++++++++++++++++++++------------- src/compiler.jl | 62 ++++++-- src/compiler/optimize.jl | 13 +- src/typetree.jl | 2 +- test/Project.toml | 2 +- 5 files changed, 270 insertions(+), 121 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index dba99d2b00..67c2170ce3 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -5,6 +5,23 @@ function absint(arg::LLVM.Value, partial::Bool = false) if isa(arg, LLVM.BitCastInst) || 
isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end + if isa(arg, ConstantExpr) && value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) + ce = arg + while isa(ce, ConstantExpr) + if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || + opcode(ce) == LLVM.API.LLVMBitCast || + opcode(ce) == LLVM.API.LLVMIntToPtr + ce = operands(ce)[1] + else + break + end + end + if isa(ce, LLVM.ConstantInt) + ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce)) + typ = Base.unsafe_pointer_to_objref(ptr) + return (true, typ) + end + end if isa(arg, ConstantExpr) if opcode(arg) == LLVM.API.LLVMAddrSpaceCast || opcode(arg) == LLVM.API.LLVMBitCast return absint(operands(arg)[1], partial) @@ -103,17 +120,6 @@ function absint(arg::LLVM.Value, partial::Bool = false) end end end - if isa(arg, ConstantExpr) - ce = arg - if opcode(ce) == LLVM.API.LLVMIntToPtr - ce = operands(ce)[1] - if isa(ce, LLVM.ConstantInt) - ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce)) - typ = Base.unsafe_pointer_to_objref(ptr) - return (true, typ) - end - end - end if isa(arg, GlobalVariable) gname = LLVM.name(arg) @@ -166,6 +172,9 @@ function actual_size(@nospecialize(typ2)) return sizeof(Int) end else + if typ2 <: GenericMemory + return sum(map(sizeof,fieldtypes(typ2))) + end end if typ2 <: AbstractString || typ2 <: Symbol return sizeof(Int) @@ -213,9 +222,8 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end -function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int, Bool} +function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int} offset = 0 - error = false while true if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) larg = operands(larg)[1] @@ -235,22 +243,38 @@ function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int, Bool} if isa(larg, LLVM.Argument) break end - error = true break end - return larg, offset, error + return larg, offset end function abs_typeof( arg::LLVM.Value, - partial::Bool = 
false, + partial::Bool = false, seenphis=Set{LLVM.PHIInst}() )::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}} if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) - return abs_typeof(operands(arg)[1], partial) + return abs_typeof(operands(arg)[1], partial, seenphis) + end + if isa(arg, ConstantExpr) && value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) + ce = arg + while isa(ce, ConstantExpr) + if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || + opcode(ce) == LLVM.API.LLVMBitCast || + opcode(ce) == LLVM.API.LLVMIntToPtr + ce = operands(ce)[1] + else + break + end + end + if isa(ce, LLVM.ConstantInt) + ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce)) + val = Base.unsafe_pointer_to_objref(ptr) + return (true, Core.Typeof(val), GPUCompiler.BITS_REF) + end end if isa(arg, ConstantExpr) if opcode(arg) == LLVM.API.LLVMAddrSpaceCast || opcode(arg) == LLVM.API.LLVMBitCast - return abs_typeof(operands(arg)[1], partial) + return abs_typeof(operands(arg)[1], partial, seenphis) end end @@ -262,11 +286,11 @@ function abs_typeof( end if nm == "julia.pointer_from_objref" - return abs_typeof(operands(arg)[1], partial) + return abs_typeof(operands(arg)[1], partial, seenphis) end if nm == "julia.gc_loaded" - legal, res, byref = abs_typeof(operands(arg)[2], partial) + legal, res, byref = abs_typeof(operands(arg)[2], partial, seenphis) return legal, res, byref end @@ -342,7 +366,7 @@ function abs_typeof( unionalls = [] legal = true for sarg in operands(arg)[index:end-1] - slegal, foundv, _ = abs_typeof(sarg, partial) + slegal, foundv, _ = abs_typeof(sarg, partial, seenphis) if slegal push!(found, foundv) elseif partial @@ -377,7 +401,7 @@ function abs_typeof( end resvals = [] while index != length(operands(arg)) - legal, pval, _ = abs_typeof(operands(arg)[index], partial) + legal, pval, _ = abs_typeof(operands(arg)[index], partial, seenphis) if !legal break end @@ -403,7 +427,7 @@ function abs_typeof( end if nm == 
"jl_array_copy" || nm == "ijl_array_copy" - legal, RT, _ = abs_typeof(operands(arg)[1], partial) + legal, RT, _ = abs_typeof(operands(arg)[1], partial, seenphis) if legal @assert RT <: Array return (legal, RT, GPUCompiler.MUT_REF) @@ -413,7 +437,7 @@ function abs_typeof( @static if VERSION < v"1.11-" else if nm == "jl_genericmemory_copy_slice" || nm == "ijl_genericmemory_copy_slice" - legal, RT, _ = abs_typeof(operands(arg)[1], partial) + legal, RT, _ = abs_typeof(operands(arg)[1], partial, seenphis) if legal @assert RT <: Memory return (legal, RT, GPUCompiler.MUT_REF) @@ -438,103 +462,124 @@ function abs_typeof( end if isa(arg, LLVM.LoadInst) - larg, offset, error = get_base_and_offset(operands(arg)[1]) - - if !error - legal, typ, byref = abs_typeof(larg) - if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) && Base.isconcretetype(typ) - @static if VERSION < v"1.11-" - if typ <: Array && Base.isconcretetype(typ) - T = eltype(typ) - if offset == 0 - return (true, Ptr{T}, GPUCompiler.BITS_VALUE) - else - return (true, Int, GPUCompiler.BITS_VALUE) + if isa(operands(arg)[1], LLVM.ConstantExpr) && isa(value_type(arg), LLVM.PointerType) && addrspace(value_type(arg)) == Tracked + ce = operands(arg)[1] + while isa(ce, ConstantExpr) + if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || + opcode(ce) == LLVM.API.LLVMBitCast || + opcode(ce) == LLVM.API.LLVMIntToPtr + ce = operands(ce)[1] + else + break + end + end + if isa(ce, LLVM.ConstantInt) + ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) + if ptr != C_NULL + obj = Base.unsafe_pointer_to_objref(ptr) + return (true, Core.Typeof(obj), GPUCompiler.BITS_REF) + end + end + end + + larg, offset = get_base_and_offset(operands(arg)[1]) + legal, typ, byref = abs_typeof(larg, false, seenphis) + + if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) && Base.isconcretetype(typ) + @static if VERSION < v"1.11-" + if typ <: Array && Base.isconcretetype(typ) + T = eltype(typ) 
+ if offset == 0 + if !allocatedinline(T) && Base.isconcretetype(T) + T = Ptr{T} end + return (true, Ptr{T}, GPUCompiler.BITS_VALUE) + else + return (true, Int, GPUCompiler.BITS_VALUE) end end - if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF - dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - - byref = GPUCompiler.BITS_VALUE - legal = true + end + if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) + + byref = GPUCompiler.BITS_VALUE + legal = true - while offset != 0 && legal - @assert Base.isconcretetype(typ) - seen = false - lasti = 1 - for i = 1:fieldcount(typ) - fo = fieldoffset(typ, i) - if fieldoffset(typ, i) == offset - offset = 0 - typ = typed_fieldtype(typ, i) - if !Base.allocatedinline(typ) - if byref != GPUCompiler.BITS_VALUE - legal = false - end - byref = GPUCompiler.MUT_REF - end - seen = true - break - elseif fieldoffset(typ, i) > offset - offset = offset - fieldoffset(typ, lasti) - typ = typed_fieldtype(typ, lasti) - @assert Base.isconcretetype(typ) - if !Base.allocatedinline(typ) + while offset != 0 && legal + @assert Base.isconcretetype(typ) + seen = false + lasti = 1 + for i = 1:fieldcount(typ) + fo = fieldoffset(typ, i) + if fieldoffset(typ, i) == offset + offset = 0 + typ = typed_fieldtype(typ, i) + if !Base.allocatedinline(typ) + if byref != GPUCompiler.BITS_VALUE legal = false end - seen = true - break + byref = GPUCompiler.MUT_REF end - - if fo != 0 && fo != fieldoffset(typ, i-1) - lasti = i + seen = true + break + elseif fieldoffset(typ, i) > offset + offset = offset - fieldoffset(typ, lasti) + typ = typed_fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) + if !Base.allocatedinline(typ) + legal = false end + seen = true + break + end + + if fo != 0 && fo != fieldoffset(typ, i-1) + lasti = i end - if !seen && fieldcount(typ) > 0 - offset = offset - fieldoffset(typ, lasti) - typ = typed_fieldtype(typ, lasti) - @assert 
Base.isconcretetype(typ) - if !Base.allocatedinline(typ) - legal = false - end - seen = true - end - if !seen + end + if !seen && fieldcount(typ) > 0 + offset = offset - fieldoffset(typ, lasti) + typ = typed_fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) + if !Base.allocatedinline(typ) legal = false end + seen = true end - - typ2 = typ - while legal && should_recurse(typ2, value_type(arg), byref, dl) - idx, _ = first_non_ghost(typ2) - if idx != -1 - typ2 = typed_fieldtype(typ2, idx) - if Base.allocatedinline(typ2) - if byref == GPUCompiler.BITS_VALUE - continue - end + if !seen + legal = false + end + end + + typ2 = typ + while legal && should_recurse(typ2, value_type(arg), byref, dl) + idx, _ = first_non_ghost(typ2) + if idx != -1 + typ2 = typed_fieldtype(typ2, idx) + if Base.allocatedinline(typ2) + if byref == GPUCompiler.BITS_VALUE + continue + end + legal = false + break + else + if byref != GPUCompiler.BITS_VALUE legal = false break - else - if byref != GPUCompiler.BITS_VALUE - legal = false - break - end - byref = GPUCompiler.MUT_REF - continue end + byref = GPUCompiler.MUT_REF + continue end - legal = false - break - end - if legal - return (true, typ2, byref) end + legal = false + break + end + if legal + return (true, typ2, byref) end - elseif legal && typ <: Ptr && Base.isconcretetype(typ) - return (true, eltype(typ), GPUCompiler.BITS_VALUE) end + elseif legal && typ <: Ptr && Base.isconcretetype(typ) + return (true, eltype(typ), GPUCompiler.BITS_VALUE) end end @@ -543,7 +588,7 @@ function abs_typeof( indptrs = LLVM.API.LLVMGetIndices(arg) numind = LLVM.API.LLVMGetNumIndices(arg) offset = Cuint[unsafe_load(indptrs, i) for i = 1:numind] - found, typ, byref = abs_typeof(larg, partial) + found, typ, byref = abs_typeof(larg, partial, seenphis) if !found return (false, nothing, nothing) end @@ -571,6 +616,67 @@ function abs_typeof( end end + if isa(arg, LLVM.PHIInst) + if arg in seenphis + return (false, nothing, nothing) + end + todo = 
LLVM.PHIInst[arg] + ops = LLVM.Value[] + seen = Set{LLVM.PHIInst}() + legal = true + while length(todo) > 0 + cur = pop!(todo) + if cur in seen + continue + end + push!(seen, cur) + for (v, _) in LLVM.incoming(cur) + v2, off = get_base_and_offset(v) + if off != 0 + if isa(v, LLVM.Instruction) && arg in collect(operands(v)) + legal = false + break + end + push!(ops, v) + elseif v2 isa LLVM.PHIInst + push!(todo, v2) + else + if isa(v2, LLVM.Instruction) && arg in collect(operands(v2)) + legal = false + break + end + push!(ops, v2) + end + end + end + if legal + resvals = nothing + seenphis2 = copy(seenphis) + push!(seenphis2, arg) + for op in ops + tmp = abs_typeof(op, partial, seenphis2) + if resvals == nothing + resvals = tmp + else + if tmp[1] == false || resvals[1] == false + resvals = (false, nothing, nothing) + break + elseif tmp[2] == resvals[2] && ( tmp[3] == resvals[3] || ( in(tmp[3],(GPUCompiler.BITS_REF, GPUCompiler.MUT_REF)) && in(resvals[3],(GPUCompiler.BITS_REF, GPUCompiler.MUT_REF))) ) + + continue + elseif partial + resvals = (true, Union{resvals[2], tmp[2]}, GPUCompiler.BITS_REF) + else + resvals = (false, nothing, nothing) + break + end + end + end + if resvals != nothing + return resvals + end + end + end if isa(arg, LLVM.Argument) f = LLVM.Function(LLVM.API.LLVMGetParamParent(arg)) diff --git a/src/compiler.jl b/src/compiler.jl index 22f0c21b13..c4425131ef 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1607,6 +1607,7 @@ function julia_error( end legal, TT, byref = abs_typeof(cur, true) + if legal if guaranteed_const_nongen(TT, world) return make_batched(ncur, prevbb) @@ -1644,6 +1645,19 @@ else end end +@static if VERSION < v"1.11-" +else + if isa(cur, LLVM.LoadInst) + larg, off = get_base_and_offset(operands(cur)[1]) + if isa(larg, LLVM.LoadInst) + legal2, obj = absint(larg) + if legal2 && obj isa Memory && obj == typeof(obj).instance + return make_batched(ncur, prevbb) + end + end + end +end + badval = if legal2 string(obj) * " of type" 
* " " * string(TT) else @@ -1855,6 +1869,14 @@ end push!(created, phi2) return phi2 end + + tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, cur)) + st = API.EnzymeTypeTreeToString(tt) + st2 = Base.unsafe_string(st) + API.EnzymeStringFree(st) + if st2 == "{[-1]:Integer}" + return make_batched(ncur, prevbb) + end illegal = true illegalVal = cur @@ -6855,41 +6877,49 @@ end fn = isa(inst, LLVM.CallInst) ? LLVM.called_operand(inst) : nothing if !API.HasFromStack(inst) && - isa(inst, LLVM.CallInst) && - (!isa(fn, LLVM.Function) || isempty(blocks(fn))) + ((isa(inst, LLVM.CallInst) && + (!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst)) legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal - typ = if codegen_typ isa LLVM.PointerType + typ = if codegen_typ isa LLVM.PointerType || codegen_typ isa LLVM.IntegerType llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) # pointers are used for multiple kinds of arguments # - literal pointer values - if source_typ <: Ptr || source_typ <: Core.LLVMPtr + if byref == GPUCompiler.BITS_VALUE source_typ elseif byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF Ptr{source_typ} else - # println(string(mod)) - println(string(f)) - @show legal, source_typ, byref, llvm_source_typ, codegen_typ, string(inst) - @show enzyme_custom_extract_mi(f) - @assert false + msg = sprint() do io + println(io, "Enzyme illegal state") + println(io, string(f)) + println(io, "legal=", legal) + println(io, "source_typ=", source_typ) + println(io, "byref=", byref) + println(io, "llvm_source_typ=", llvm_source_typ) + println(io, "codegen_typ=", codegen_typ) + println(io, "inst=", string(inst)) + println(io, enzyme_custom_extract_mi(f)) + end + throw(AssertionError(msg)) end else source_typ end + ec = typetree(typ, ctx, dl, seen) if isa(inst, LLVM.CallInst) LLVM.API.LLVMAddCallSiteAttribute( inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute( "enzyme_type", - 
string(typetree(typ, ctx, dl, seen)), + string(ec), ), ) else - metadata(inst)["enzyme_type"] = to_md(typetree(typ, ctx, dl, seen), ctx) + metadata(inst)["enzyme_type"] = to_md(ec, ctx) end elseif codegen_typ == T_prjlvalue if isa(inst, LLVM.CallInst) @@ -6918,7 +6948,7 @@ end if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id - base, offset, _ = get_base_and_offset(operands(inst)[1]) + base, offset = get_base_and_offset(operands(inst)[1]) legal, jTy, byref = abs_typeof(base) sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || @@ -6959,6 +6989,7 @@ end if !legal continue end + if !guaranteed_const_nongen(jTy, world) continue end @@ -8029,8 +8060,11 @@ end push!(sret_types, Nothing) elseif rettype <: Const else - @show rettype, CC - @assert false + msg = sprint() do io + println(io, "rettype=", rettype) + println(io, "CC=", CC) + end + throw(AssertionError(msg)) end end diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index dc26d140bb..69fdfa18f5 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -627,6 +627,7 @@ function nodecayed_phis!(mod::LLVM.Module) # Simple handler to fix addrspace 11 #complex handler for addrspace 13, which itself comes from a load of an # addrspace 10 + ctx = LLVM.context(mod) for f in functions(mod) guaranteedInactive = false @@ -826,17 +827,25 @@ function nodecayed_phis!(mod::LLVM.Module) v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true) rhs = LLVM.ConstantInt(offty, sizeof(Int)) - base_2, off_2, _ = get_base_and_offset(v2) - base_1, off_1, _ = get_base_and_offset(operands(v)[1]) + base_2, off_2 = get_base_and_offset(v2) + base_1, off_1 = get_base_and_offset(operands(v)[1]) if o2 == rhs && base_1 == base_2 && off_1 == off_2 return operands(v)[1], offset, true end + pty = TypeTree(API.DT_Pointer, LLVM.context(ld)) + only!(pty, -1) rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty) + 
metadata(rhs)["enzyme_type"] = to_md(pty, ctx) lhs = ptrtoint!(b, operands(v)[2], offty) + metadata(rhs)["enzyme_type"] = to_md(pty, ctx) off2 = nuwsub!(b, lhs, rhs) + ity = TypeTree(API.DT_Integer, LLVM.context(ld)) + only!(ity, -1) + metadata(off2)["enzyme_type"] = to_md(ity, ctx) add = nuwadd!(b, offset, off2) + metadata(add)["enzyme_type"] = to_md(ity, ctx) return operands(v)[1], add, true end end diff --git a/src/typetree.jl b/src/typetree.jl index 8224b98952..adf3738625 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -254,7 +254,7 @@ function typetree_inner( dl, seen::TypeTreeTable, ) where {T} - tt = copy(typetree(T, ctx, dl, seen)) + tt = copy(typetree(T == UInt8 ? Nothing : T, ctx, dl, seen)) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) return tt diff --git a/test/Project.toml b/test/Project.toml index 3ce8fc645c..818e0ac708 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -24,4 +24,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -EnzymeTestUtils = "0.1.4, 0.2" +EnzymeTestUtils = "0.2.1" From 4d28fa699c488ef56e79e392a9eccfcf3fefd0ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 1 Nov 2024 08:37:19 -0500 Subject: [PATCH 404/495] Update Project.toml (#2040) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ed748bdc77..311fbbbe29 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4, 0.8.5" -Enzyme_jll = "0.0.158" +Enzyme_jll = "0.0.159" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 04dcea79851256ce71ea480cf814af707ac21327 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 3 Nov 2024 12:44:37 -0600 Subject: [PATCH 405/495] More 1.11 work (#2044) * More 1.11 work * fix * fix * fix * fix --- src/absint.jl | 38 +++++++++--- src/compiler.jl | 121 
+++++++++++++++++++++++++++---------- src/compiler/optimize.jl | 2 +- src/compiler/utils.jl | 33 +++++++--- src/compiler/validation.jl | 47 +++++++------- test/runtests.jl | 2 +- 6 files changed, 170 insertions(+), 73 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 67c2170ce3..05723115e9 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -222,10 +222,24 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end -function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int} +function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=false)::Tuple{LLVM.Value, Int} offset = 0 while true - if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) + if isa(larg, LLVM.ConstantExpr) + if opcode(larg) == LLVM.API.LLVMBitCast || opcode(larg) == LLVM.API.LLVMAddrSpaceCast || opcode(larg) == LLVM.API.LLVMPtrToInt + larg = operands(larg)[1] + continue + end + if inttoptr && opcode(larg) == LLVM.API.LLVMIntToPtr + larg = operands(larg)[1] + continue + end + end + if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) || isa(larg, LLVM.IntToPtrInst) + larg = operands(larg)[1] + continue + end + if inttoptr && isa(larg, LLVM.PtrToIntInst) larg = operands(larg)[1] continue end @@ -235,10 +249,18 @@ function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int} position!(b, larg) offty = LLVM.IntType(8 * sizeof(Int)) offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) - @assert isa(offset2, LLVM.ConstantInt) - offset += convert(Int, offset2) - larg = operands(larg)[1] - continue + if isa(offset2, LLVM.ConstantInt) + val = convert(Int, offset2) + if offsetAllowed || val == 0 + offset += val + larg = operands(larg)[1] + continue + else + break + end + else + break + end end if isa(larg, LLVM.Argument) break @@ -429,7 +451,9 @@ function abs_typeof( if nm == "jl_array_copy" || nm == "ijl_array_copy" legal, RT, _ = abs_typeof(operands(arg)[1], partial, seenphis) if legal - @assert RT <: Array + 
if !(RT <: Array) + return (false, nothing, nothing) + end return (legal, RT, GPUCompiler.MUT_REF) end return (legal, RT, nothing) diff --git a/src/compiler.jl b/src/compiler.jl index c4425131ef..69ceeabc05 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -423,6 +423,7 @@ const inactiveglobs = Set{String}(( "jl_boxed_uint8_cache", "ijl_boxed_int8_cache", "jl_boxed_int8_cache", + "jl_nothing", )) @enum ActivityState begin @@ -1104,6 +1105,38 @@ struct Return2 ret2::Any end +function force_recompute!(mod::LLVM.Module) + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.LoadInst) + has_loaded = false + for u in LLVM.uses(inst) + v = LLVM.user(u) + if isa(v, LLVM.CallInst) + cf = LLVM.called_operand(v) + if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" && operands(v)[2] == inst + has_loaded = true + break + end + end + if isa(v, LLVM.BitCastInst) + for u2 in LLVM.uses(v) + v2 = LLVM.user(u2) + if isa(v2, LLVM.CallInst) + cf = LLVM.called_operand(v2) + if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" && operands(v2)[2] == v + has_loaded = true + break + end + end + end + end + end + if has_loaded + metadata(inst)["enzyme_nocache"] = MDNode(LLVM.Metadata[]) + end + end + end +end function permit_inlining!(f::LLVM.Function) for bb in blocks(f), inst in instructions(bb) # remove illegal invariant.load and jtbaa_const invariants @@ -3317,6 +3350,7 @@ function annotate!(mod, mode) if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) end end @@ -3378,6 +3412,23 @@ function annotate!(mod, mode) end end + for boxfn in ( + "julia.gc_alloc_obj", + "jl_gc_alloc_typed", + "ijl_gc_alloc_typed", + ) + if haskey(fns, boxfn) + fn = fns[boxfn] + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + fn, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + 
kind(EnumAttribute("allockind")), + ) + push!(function_attributes(fn), no_escaping_alloc) + push!(function_attributes(fn), LLVM.EnumAttribute("allockind", (AllocFnKind(AFKE_Alloc) | AllocFnKind(AFKE_Uninitialized)).data)) + end + end + for boxfn in ( "julia.gc_alloc_obj", "jl_gc_alloc_typed", @@ -3417,24 +3468,39 @@ function annotate!(mod, mode) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) push!(function_attributes(fn), no_escaping_alloc) + push!(function_attributes(fn), LLVM.EnumAttribute("mustprogress")) + push!(function_attributes(fn), LLVM.EnumAttribute("willreturn")) + push!(function_attributes(fn), LLVM.EnumAttribute("nounwind")) + push!(function_attributes(fn), LLVM.EnumAttribute("nofree")) accattr = if LLVM.version().major <= 15 LLVM.EnumAttribute("inaccessiblememonly") else - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) + if boxfn in ( + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice",) + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + else + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + end end if !( boxfn in ( "jl_array_copy", "ijl_array_copy", - "jl_genericmemory_copy_slice", - "ijl_genericmemory_copy_slice", "jl_idtable_rehash", "ijl_idtable_rehash", ) @@ -3457,8 +3523,6 @@ function annotate!(mod, mode) boxfn in ( "jl_array_copy", "ijl_array_copy", - "jl_genericmemory_copy_slice", - "ijl_genericmemory_copy_slice", "jl_idtable_rehash", "ijl_idtable_rehash", ) @@ -3476,10 +3540,8 @@ function annotate!(mod, mode) if !isa(cf, LLVM.Function) continue end - if 
LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" - continue - end - if operands(c)[1] != fn + if !(cf == fn || + ((LLVM.name(cf) == "julia.call" || LLVM.name(cf) != "julia.call2") && operands(c)[1] == fn)) continue end LLVM.API.LLVMAddCallSiteAttribute( @@ -3499,31 +3561,17 @@ function annotate!(mod, mode) boxfn in ( "jl_array_copy", "ijl_array_copy", - "jl_genericmemory_copy_slice", - "ijl_genericmemory_copy_slice", "jl_idtable_rehash", "ijl_idtable_rehash", ) ) - attr = if LLVM.version().major <= 15 - LLVM.EnumAttribute("inaccessiblememonly") - else - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) - end LLVM.API.LLVMAddCallSiteAttribute( c, reinterpret( LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex, ), - attr, + accattr, ) end end @@ -6920,6 +6968,16 @@ end ) else metadata(inst)["enzyme_type"] = to_md(ec, ctx) + +@static if VERSION < v"1.11-" +else + legal2, obj = absint(inst) + if legal2 obj isa Memory && obj == typeof(obj).instance + metadata(inst)["nonnull"] = MDNode(LLVM.Metadata[]) + end +end + + end elseif codegen_typ == T_prjlvalue if isa(inst, LLVM.CallInst) @@ -7152,6 +7210,7 @@ end if params.run_enzyme # Generate the adjoint memcpy_alloca_to_loadstore(mod) + force_recompute!(mod) adjointf, augmented_primalf, TapeType = enzyme!( job, diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 69fdfa18f5..1a4d450074 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -716,7 +716,7 @@ function nodecayed_phis!(mod::LLVM.Module) while length(addrtodo) != 0 v = pop!(addrtodo) - base = get_base_object(v) + base, _ = get_base_and_offset(v; offsetAllowed=false) if in(base, seen) continue end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 5539c5ed06..09ff90a50a 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -1,7 
+1,30 @@ + +@enum(AllocFnKindEnum, + AFKE_Unknown = 0, + AFKE_Alloc = 1, + AFKE_Realloc = 2, + AFKE_Free = 4, + AFKE_Uninitialized = 8, + AFKE_Zeroed = 16, + AFKE_Aligned = 32, +) + +struct AllocFnKind + data::UInt32 + AllocFnKind() = new(0) + AllocFnKind(x::UInt32) = new(x) + AllocFnKind(x::AllocFnKindEnum) = new(UInt32(x)) +end + +function Base.:|(lhs::AllocFnKind, rhs::AllocFnKind) + AllocFnKind(UInt32(lhs.data) | UInt32(rhs.data)) +end + struct MemoryEffect data::UInt32 end + @enum(ModRefInfo, MRI_NoModRef = 0, MRI_Ref = 1, MRI_Mod = 2, MRI_ModRef = 3) @enum(IRMemLocation, ArgMem = 0, InaccessibleMem = 1, Other = 2) @@ -277,16 +300,6 @@ end T_ppjlvalue() = LLVM.PointerType(LLVM.PointerType(LLVM.StructType(LLVMType[]))) -@inline function get_base_object(v) - if isa(v, LLVM.AddrSpaceCastInst) || isa(v, LLVM.BitCastInst) - return get_base_object(operands(v)[1]) - end - if isa(v, LLVM.GetElementPtrInst) - return get_base_object(operands(v)[1]) - end - return v -end - function declare_pgcstack!(mod) get_function!( mod, diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 4f341ac3f5..ee2a7120fe 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -417,24 +417,6 @@ function check_ir!(job, errors, mod::LLVM.Module) return errors end - -function unwrap_ptr_casts(val::LLVM.Value) - while true - is_simple_cast = false - is_simple_cast |= isa(val, LLVM.BitCastInst) - is_simple_cast |= isa(val, LLVM.AddrSpaceCastInst) || isa(val, LLVM.PtrToIntInst) - is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMAddrSpaceCast - is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMIntToPtr - is_simple_cast |= isa(val, LLVM.ConstantExpr) && opcode(val) == LLVM.API.LLVMBitCast - - if !is_simple_cast - return val - else - val = operands(val)[1] - end - end -end - function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) calls = [] isInline = 
API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 @@ -445,7 +427,7 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) # remove illegal invariant.load and jtbaa_const invariants elseif isa(inst, LLVM.LoadInst) - fn_got = unwrap_ptr_casts(operands(inst)[1]) + fn_got, _ = get_base_and_offset(operands(inst)[1]; offsetAllowed=false, inttoptr=false) fname = String(name(fn_got)) match_ = match(r"^jlplt_(.*)_\d+_got$", fname) @@ -473,7 +455,7 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) @assert FT !== nothing newf, _ = get_function!(mod, String(fname), FT) - initfn = unwrap_ptr_casts(LLVM.initializer(fn_got)) + initfn, _ = get_base_and_offset(LLVM.initializer(fn_got); offsetAllowed=false, inttoptr=false) loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst opv = operands(loadfn)[1] if !isa(opv, LLVM.GlobalVariable) @@ -902,10 +884,27 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) fname = ops[2] if isa(flib, LLVM.LoadInst) - op = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(flib, 0)) - while isa(op, LLVM.ConstantExpr) - op = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(op, 0)) + op, _ = get_base_and_offset(operands(flib)[1]; offsetAllowed=false, inttoptr=true) + + if isa(op, LLVM.LoadInst) + pop, _ = get_base_and_offset(operands(op)[1]; offsetAllowed=false, inttoptr=true) + + if isa(pop, LLVM.GlobalVariable) + zop, _ = get_base_and_offset(LLVM.initializer(pop); offsetAllowed=false, inttoptr=true) + + rep = zop + PT = value_type(rep) + if isa(PT, LLVM.PointerType) + rep = LLVM.const_inttoptr(rep, LLVM.PointerType(eltype(PT))) + rep = LLVM.const_addrspacecast(rep, PT) + replace_uses!(pop, rep) + LLVM.API.LLVMInstructionEraseFromParent(pop) + end + + op = zop + end end + if isa(op, ConstantInt) rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, op) + 8) ld = unsafe_load(convert(Ptr{Ptr{Cvoid}}, rep)) @@ -923,6 +922,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, 
calls) if isa(fname, LLVM.GlobalVariable) fname = LLVM.initializer(fname) end + if (isa(fname, LLVM.ConstantArray) || isa(fname, LLVM.ConstantDataArray)) && eltype(value_type(fname)) == LLVM.IntType(8) fname = String(map((x) -> convert(UInt8, x), collect(fname)[1:(end-1)])) @@ -1017,6 +1017,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) fname, ) end + replaceWith = LLVM.ConstantInt(LLVM.IntType(8 * sizeof(Int)), reinterpret(UInt, res)) for u in LLVM.uses(inst) diff --git a/test/runtests.jl b/test/runtests.jl index d28461d26a..31dcbd3e66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -176,7 +176,7 @@ end @static if VERSION < v"1.11-" @test typeof(res[1]) == Tuple{Float64, Float64} else - @test typeof(res[1]) == NamedTuple{(Symbol("1"),Symbol("2"),Symbol("3"),Symbol("4"),Symbol("5"),Symbol("6")), Tuple{Any, Core.LLVMPtr{UInt8, 0}, Any, Core.LLVMPtr{Any, 0}, Float64, Float64}} + @test typeof(res[1]) == NamedTuple{(Symbol("1"),Symbol("2"),Symbol("3")), Tuple{Any, Float64, Float64}} end pullback(Const(mul2), d, 1.0, res[1]) From 5bf2cfd25227aaa3f4cf24291799466554aa4ca3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 3 Nov 2024 14:08:30 -0600 Subject: [PATCH 406/495] Update Project.toml (#2046) --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 311fbbbe29..1fa4a49186 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.13" +version = "0.13.14" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4, 0.8.5" -Enzyme_jll = "0.0.159" +Enzyme_jll = "0.0.160" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 1f0ff430ad48e1ac44711773bc40024f53bcd392 Mon Sep 17 
00:00:00 2001 From: William Moses Date: Sun, 3 Nov 2024 15:22:02 -0600 Subject: [PATCH 407/495] Continuing 1.11 names (#2045) * Continuing 1.11 names * cleanup * cleanup --- src/absint.jl | 57 +++++++++++++++----------------------- src/compiler/validation.jl | 24 ++-------------- 2 files changed, 24 insertions(+), 57 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 05723115e9..d693102eda 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -124,7 +124,7 @@ function absint(arg::LLVM.Value, partial::Bool = false) if isa(arg, GlobalVariable) gname = LLVM.name(arg) for (k, v) in JuliaGlobalNameMap - if gname == k || gname == "ejl_" * k + if gname == "ejl_" * k return (true, v) end end @@ -138,14 +138,13 @@ function absint(arg::LLVM.Value, partial::Bool = false) if isa(arg, LLVM.LoadInst) && value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) ptr = operands(arg)[1] - ce = ptr - while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || - opcode(ce) == LLVM.API.LLVMBitCast || - opcode(ce) == LLVM.API.LLVMIntToPtr - ce = operands(ce)[1] - else - break + ce, _ = get_base_and_offset(ptr; offsetAllowed=false, inttoptr=true) + if isa(ce, GlobalVariable) + gname = LLVM.name(ce) + for (k, v) in JuliaGlobalNameMap + if gname == k + return (true, v) + end end end if !isa(ce, LLVM.ConstantInt) @@ -278,14 +277,13 @@ function abs_typeof( return abs_typeof(operands(arg)[1], partial, seenphis) end if isa(arg, ConstantExpr) && value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) - ce = arg - while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || - opcode(ce) == LLVM.API.LLVMBitCast || - opcode(ce) == LLVM.API.LLVMIntToPtr - ce = operands(ce)[1] - else - break + ce, _ = get_base_and_offset(arg; offsetAllowed=false, inttoptr=true) + if isa(ce, GlobalVariable) + gname = LLVM.name(ce) + for (k, v) in JuliaGlobalNameMap + if gname == k + return (true, Core.Typeof(v), GPUCompiler.BITS_REF) + end end 
end if isa(ce, LLVM.ConstantInt) @@ -485,27 +483,16 @@ function abs_typeof( end end - if isa(arg, LLVM.LoadInst) - if isa(operands(arg)[1], LLVM.ConstantExpr) && isa(value_type(arg), LLVM.PointerType) && addrspace(value_type(arg)) == Tracked - ce = operands(arg)[1] - while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || - opcode(ce) == LLVM.API.LLVMBitCast || - opcode(ce) == LLVM.API.LLVMIntToPtr - ce = operands(ce)[1] - else - break - end - end - if isa(ce, LLVM.ConstantInt) - ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) - if ptr != C_NULL - obj = Base.unsafe_pointer_to_objref(ptr) - return (true, Core.Typeof(obj), GPUCompiler.BITS_REF) + if isa(arg, LLVM.LoadInst) + ce, _ = get_base_and_offset(operands(arg)[1]; offsetAllowed=false, inttoptr=true) + if isa(ce, GlobalVariable) + gname = LLVM.name(ce) + for (k, v) in JuliaGlobalNameMap + if gname == k + return (true, Core.Typeof(v), GPUCompiler.BITS_REF) end end end - larg, offset = get_base_and_offset(operands(arg)[1]) legal, typ, byref = abs_typeof(larg, false, seenphis) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index ee2a7120fe..5216244e52 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -508,17 +508,7 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) legal1, arg1 = abs_cstring(operands(found)[1]) if legal1 else - arg1 = operands(found)[1] - - while isa(arg1, ConstantExpr) - if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || - opcode(arg1) == LLVM.API.LLVMBitCast || - opcode(arg1) == LLVM.API.LLVMIntToPtr - arg1 = operands(arg1)[1] - else - break - end - end + arg1, _ = get_base_and_offset(operands(found)[1]; offsetAllowed=false, inttoptr=true) if !isa(arg1, LLVM.ConstantInt) msg = sprint() do io::IO println( @@ -756,17 +746,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) - arg1 = operands(inst)[1] - - while 
isa(arg1, ConstantExpr) - if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || - opcode(arg1) == LLVM.API.LLVMBitCast || - opcode(arg1) == LLVM.API.LLVMIntToPtr - arg1 = operands(arg1)[1] - else - break - end - end + arg1, _ = get_base_and_offset(operands(inst)[1]; offsetAllowed=false, inttoptr=true) if isa(arg1, LLVM.ConstantInt) arg1 = reinterpret(Ptr{Cvoid}, convert(UInt, arg1)) legal2, fname = abs_cstring(operands(inst)[2]) From cb4c695b5807bbe3a467741eaaa593c7455b5e0d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 3 Nov 2024 15:22:23 -0600 Subject: [PATCH 408/495] Optimize undef arg (#2047) --- src/compiler/optimize.jl | 40 +++++++++++++++------- test/optimize.jl | 71 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 11 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 1a4d450074..2967d5d6c0 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1836,7 +1836,7 @@ function propagate_returned!(mod::LLVM.Module) LLVM.replace_uses!(arg, val) end end - # sese if there are no users of the value (excluding recursive/return) + # see if there are no users of the value (excluding recursive/return) baduse = false for u in LLVM.uses(arg) u = LLVM.user(u) @@ -1923,13 +1923,14 @@ function propagate_returned!(mod::LLVM.Module) end end for (fn, keepret, toremove) in tofinalize - try - todo = LLVM.CallInst[] - for u in LLVM.uses(fn) - un = LLVM.user(u) - push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) - end - delete_writes_into_removed_args(fn, toremove) + todo = LLVM.CallInst[] + for u in LLVM.uses(fn) + un = LLVM.user(u) + push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) + end + delete_writes_into_removed_args(fn, toremove) + nm = LLVM.name(fn) + #try nfn = LLVM.Function( API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove), ) @@ -1946,9 +1947,9 @@ function propagate_returned!(mod::LLVM.Module) end eraseInst(mod, fn) changed = true - catch - break - end + # catch e + # 
break + #end end if !changed break @@ -2000,6 +2001,23 @@ function delete_writes_into_removed_args(fn::LLVM.Function, toremove) end continue end + if isa(cur, LLVM.CallInst) + cf = LLVM.called_operand(cur) + if cf == fn + baduse = false + for (i, v) in enumerate(operands(cur)) + if i-1 in toremove + continue + end + if v == cval + baduse = true + end + end + if !baduse + continue + end + end + end throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))")) end end diff --git a/test/optimize.jl b/test/optimize.jl index a4fcc1768f..7bc89ecc74 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -1,4 +1,5 @@ using Enzyme, LinearAlgebra, Test +using Random, Statistics function gcloaded_fixup(dest, src) N = size(src) @@ -44,3 +45,73 @@ end gcloaded_fixup(dest, H) @test dest ≈ [4.0 2.0; 2.0 5.0] end + +struct MyNormal + sigma::Float64 + off::Float64 +end + +struct MvLocationScale{ + S, D, L +} + location ::L + scale ::S + dist ::D +end + +@noinline function law(dist, flat::AbstractVector) + ccall(:jl_, Cvoid, (Any,), flat) + n_dims = div(length(flat), 2) + data = first(flat, n_dims) + scale = Diagonal(data) + return MvLocationScale(nothing, scale, dist) +end + +function destructure(q::MvLocationScale) + return diag(q.scale) +end + + +myxval(d::MyNormal, z::Real) = muladd(d.sigma, z, d.off) + +function myrand!(rng::AbstractRNG, d::MyNormal, A::AbstractArray{<:Real}) + # randn!(rng, A) + map!(Base.Fix1(myxval, d), A, A) + return A +end + +function man(q::MvLocationScale) + dist = MyNormal(1.0, 0.0) + + out = ones(2,3) # Array{Float64}(undef, (2,3)) + @inbounds myrand!(Random.default_rng(), dist, out) + + return q.scale[1] * out +end + +function estimate_repgradelbo_ad_forward(params, dist) + q = law(dist, params) + samples = man(q) + mean(samples) +end + +@testset "Removed undef arguments" begin + T = Float64 + d = 2 + dist = MyNormal(1.0, 0.0) + q = MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), dist) + params = 
destructure(q) + + ∇x = zero(params) + fill!(∇x, zero(eltype(∇x))) + + estimate_repgradelbo_ad_forward(params, dist) + + Enzyme.autodiff( + set_runtime_activity(Enzyme.ReverseWithPrimal), + estimate_repgradelbo_ad_forward, + Enzyme.Active, + Enzyme.Duplicated(params, ∇x), + Enzyme.Const(dist) + ) +end From 3aebe9b3994933bf0a764a7689a93c121a0bf877 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 3 Nov 2024 15:53:46 -0600 Subject: [PATCH 409/495] Remove unused symbols (#2048) --- src/compiler/optimize.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 2967d5d6c0..2277f27250 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -2818,6 +2818,18 @@ function post_optimze!(mod, tm, machine = true) for f in collect(functions(mod)) API.EnzymeFixupBatchedJuliaCallingConvention(f) end + for g in collect(globals(mod)) + if startswith(LLVM.name(g), "ccall") + hasuse = false + for u in LLVM.uses(g) + hasuse = true + break + end + if !hasuse + eraseInst(mod, g) + end + end + end out_error = Ref{Cstring}() if LLVM.API.LLVMVerifyModule(mod, LLVM.API.LLVMReturnStatusAction, out_error) != 0 throw( From 46fe1cf425e1ff60437c6b3d2bdf71953ba172b6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 3 Nov 2024 16:42:58 -0600 Subject: [PATCH 410/495] Fix size calculation for simplevector (#2049) --- src/absint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index d693102eda..e006596b71 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -175,7 +175,7 @@ function actual_size(@nospecialize(typ2)) return sum(map(sizeof,fieldtypes(typ2))) end end - if typ2 <: AbstractString || typ2 <: Symbol + if typ2 <: AbstractString || typ2 <: Symbol || typ2 <: Core.SimpleVector return sizeof(Int) elseif Base.isconcretetype(typ2) return sizeof(typ2) From 7d5fc7630deb2be60a08891904366847b11356bb Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 3 Nov 
2024 20:03:04 -0600 Subject: [PATCH 411/495] Fix jit lookup functionality (#2050) * Fix jit lookup functionality * fix * Fix * fix --- src/compiler.jl | 710 ++++++++++++++++++++----------------- src/compiler/orcv2.jl | 41 ++- src/compiler/validation.jl | 68 +--- 3 files changed, 431 insertions(+), 388 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 69ceeabc05..00632463c8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -292,6 +292,8 @@ const nofreefns = Set{String}(( "ijl_f_getfield", "jl_pop_handler", "ijl_pop_handler", + "jl_pop_handler_noexcept", + "ijl_pop_handler_noexcept", "jl_string_to_array", "ijl_string_to_array", "jl_alloc_string", @@ -3129,9 +3131,23 @@ function annotate!(mod, mode) inactive = LLVM.StringAttribute("enzyme_inactive", "") active = LLVM.StringAttribute("enzyme_active", "") no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - fns = functions(mod) - for f in fns + funcs = Dict{String, Vector{LLVM.Function}}() + for f in functions(mod) + fname = LLVM.name(f) + for fattr in collect(function_attributes(f)) + if isa(fattr, LLVM.StringAttribute) + if kind(fattr) == "enzyme_math" + fname = LLVM.value(fattr) + break + end + end + end + fname = String(fname) + if !haskey(funcs, fname) + funcs[fname] = LLVM.Function[] + end + push!(funcs[String(fname)], f) API.EnzymeAttributeKnownFunctions(f.ref) end @@ -3144,129 +3160,136 @@ function annotate!(mod, mode) end for fname in inactivefns - if haskey(fns, fname) - fn = fns[fname] - push!(function_attributes(fn), inactive) - push!(function_attributes(fn), no_escaping_alloc) - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue - end - cf = LLVM.called_operand(c) - if !isa(cf, LLVM.Function) - continue - end - if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" - continue - end - if operands(c)[1] != fn - continue + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), inactive) + 
push!(function_attributes(fn), no_escaping_alloc) + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if !isa(cf, LLVM.Function) + continue + end + if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" + continue + end + if operands(c)[1] != fn + continue + end + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - inactive, - ) - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - no_escaping_alloc, - ) end end end for fname in nofreefns - if haskey(fns, fname) - fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("nofree", 0)) - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue - end - cf = LLVM.called_operand(c) - if !isa(cf, LLVM.Function) - continue - end - if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" - continue - end - if operands(c)[1] != fn - continue + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.EnumAttribute("nofree", 0)) + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if !isa(cf, LLVM.Function) + continue + end + if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" + continue + end + if operands(c)[1] != fn + continue + end + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + LLVM.EnumAttribute("nofree", 0), + ) end - 
LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - LLVM.EnumAttribute("nofree", 0), - ) end end end for fname in activefns - if haskey(fns, fname) - fn = fns[fname] - push!(function_attributes(fn), active) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), active) + end end end for fname in ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") - if haskey(fns, fname) - fn = fns[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end for fname in ("julia.typeof",) - if haskey(fns, fname) - fn = fns[fname] - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + end end end for fname in ("jl_excstack_state", "ijl_excstack_state", "ijl_field_index", "jl_field_index") - if haskey(fns, fname) - fn = fns[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) - push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_Ref << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) + if haskey(funcs, fname) + for fn in 
funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end end end end for fname in ("jl_types_equal", "ijl_types_equal") - if haskey(fns, fname) - fn = fns[fname] - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) + end end end @@ -3278,79 +3301,82 @@ function annotate!(mod, mode) "jl_f__svec_ref", "ijl_f__svec_ref", ) - if haskey(fns, fname) - fn = fns[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - else - push!(function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + else + push!(function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) ) - ) - end - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue - end - cf = LLVM.called_operand(c) - if !isa(cf, LLVM.Function) - continue - end - if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" - continue - end - if operands(c)[1] != fn - continue end - attr = if LLVM.version().major <= 15 - LLVM.EnumAttribute("readonly") 
- else - EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if !isa(cf, LLVM.Function) + continue + end + if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" + continue + end + if operands(c)[1] != fn + continue + end + attr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("readonly") + else + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + end + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + attr, ) end - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - attr, - ) end end end for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states") - if haskey(fns, fname) - fn = fns[fname] - # TODO per discussion w keno perhaps this should change to readonly / inaccessiblememonly - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + if haskey(funcs, fname) + for fn in funcs[fname] + # TODO per discussion w keno perhaps this should change to readonly / inaccessiblememonly + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end for fname in 
("julia.gc_loaded",) - if haskey(fns, fname) - fn = fns[fname] - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + end end end @@ -3375,6 +3401,8 @@ function annotate!(mod, mode) "jl_array_del_at", "ijl_pop_handler", "jl_pop_handler", + "ijl_pop_handler_noexcept", + "jl_pop_handler_noexcept", "ijl_push_handler", "jl_push_handler", "ijl_module_name", @@ -3393,43 +3421,46 @@ function annotate!(mod, mode) "ijl_try_substrtod", "jl_try_substrtod", ) - if haskey(fns, fname) - fn = fns[fname] - push!(function_attributes(fn), no_escaping_alloc) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), no_escaping_alloc) + end end end for fname in ("julia.pointer_from_objref",) - if haskey(fns, fname) - fn = fns[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end end end end - for boxfn in ( + for fname in ( "julia.gc_alloc_obj", "jl_gc_alloc_typed", "ijl_gc_alloc_typed", ) - if haskey(fns, boxfn) - fn = fns[boxfn] - LLVM.API.LLVMRemoveEnumAttributeAtIndex( - fn, - reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), - kind(EnumAttribute("allockind")), - ) - push!(function_attributes(fn), no_escaping_alloc) - push!(function_attributes(fn), LLVM.EnumAttribute("allockind", (AllocFnKind(AFKE_Alloc) | AllocFnKind(AFKE_Uninitialized)).data)) + if haskey(funcs, 
fname) + for fn in funcs[fname] + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + fn, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("allockind")), + ) + push!(function_attributes(fn), no_escaping_alloc) + push!(function_attributes(fn), LLVM.EnumAttribute("allockind", (AllocFnKind(AFKE_Alloc) | AllocFnKind(AFKE_Uninitialized)).data)) + end end end - for boxfn in ( + for fname in ( "julia.gc_alloc_obj", "jl_gc_alloc_typed", "ijl_gc_alloc_typed", @@ -3464,63 +3495,101 @@ function annotate!(mod, mode) "ijl_new_array", "jl_new_array", ) - if haskey(fns, boxfn) - fn = fns[boxfn] - push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) - push!(function_attributes(fn), no_escaping_alloc) - push!(function_attributes(fn), LLVM.EnumAttribute("mustprogress")) - push!(function_attributes(fn), LLVM.EnumAttribute("willreturn")) - push!(function_attributes(fn), LLVM.EnumAttribute("nounwind")) - push!(function_attributes(fn), LLVM.EnumAttribute("nofree")) - accattr = if LLVM.version().major <= 15 - LLVM.EnumAttribute("inaccessiblememonly") - else - if boxfn in ( - "jl_genericmemory_copy_slice", - "ijl_genericmemory_copy_slice",) - EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) - else - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) + push!(function_attributes(fn), no_escaping_alloc) + push!(function_attributes(fn), LLVM.EnumAttribute("mustprogress")) + push!(function_attributes(fn), LLVM.EnumAttribute("willreturn")) + push!(function_attributes(fn), LLVM.EnumAttribute("nounwind")) + push!(function_attributes(fn), 
LLVM.EnumAttribute("nofree")) + accattr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("inaccessiblememonly") + else + if fname in ( + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice",) + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + else + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + end end - end - if !( - boxfn in ( - "jl_array_copy", - "ijl_array_copy", - "jl_idtable_rehash", - "ijl_idtable_rehash", + if !( + fname in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) ) - ) - push!(function_attributes(fn), accattr) - end - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue + push!(function_attributes(fn), accattr) end - cf = LLVM.called_operand(c) - if cf == fn + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if cf == fn + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) + if !( + fname in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + accattr, + ) + end + end + if !isa(cf, LLVM.Function) + continue + end + if !(cf == fn || + ((LLVM.name(cf) == "julia.call" || LLVM.name(cf) != "julia.call2") && operands(c)[1] == fn)) + continue + end LLVM.API.LLVMAddCallSiteAttribute( c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0), ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + 
LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) if !( - boxfn in ( + fname in ( "jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", @@ -3537,135 +3606,102 @@ function annotate!(mod, mode) ) end end - if !isa(cf, LLVM.Function) - continue - end - if !(cf == fn || - ((LLVM.name(cf) == "julia.call" || LLVM.name(cf) != "julia.call2") && operands(c)[1] == fn)) - continue - end - LLVM.API.LLVMAddCallSiteAttribute( - c, - LLVM.API.LLVMAttributeReturnIndex, - LLVM.EnumAttribute("noalias", 0), - ) - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - no_escaping_alloc, - ) - if !( - boxfn in ( - "jl_array_copy", - "ijl_array_copy", - "jl_idtable_rehash", - "ijl_idtable_rehash", - ) - ) - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - accattr, - ) - end end end end - for gc in ("llvm.julia.gc_preserve_begin", "llvm.julia.gc_preserve_end") - if haskey(fns, gc) - fn = fns[gc] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) + for fname in ("llvm.julia.gc_preserve_begin", "llvm.julia.gc_preserve_end") + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end end end end # Key of jl_eqtable_get/put is inactive, definitionally - for rfn in 
("jl_eqtable_get", "ijl_eqtable_get") - if haskey(fns, rfn) - fn = fns[rfn] - push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) + for fname in ("jl_eqtable_get", "ijl_eqtable_get") + if haskey(funcs, fname) + for fn in funcs[fname] + push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end end end end # Key of jl_eqtable_get/put is inactive, definitionally - for rfn in ("jl_eqtable_put", "ijl_eqtable_put") - if haskey(fns, rfn) - fn = fns[rfn] - push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) - push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) - push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) - push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_ModRef << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) + for fname in ("jl_eqtable_put", 
"ijl_eqtable_put") + if haskey(funcs, fname) + for fn in funcs[fname] + push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) + push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_ModRef << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end end end end - for rfn in ("jl_in_threaded_region_", "jl_in_threaded_region") - if haskey(fns, rfn) - fn = fns[rfn] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_Ref << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) + for fname in ("jl_in_threaded_region_", "jl_in_threaded_region") + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end end end end @@ -6015,6 +6051,7 @@ function no_type_setting(@nospecialize(specTypes); world = nothing) return (false, false) end +const DumpPreCheck = Ref(false) const DumpPreOpt = Ref(false) function 
GPUCompiler.codegen( @@ -6093,6 +6130,9 @@ function GPUCompiler.codegen( end primalf = meta.entry + if DumpPreCheck[] + API.EnzymeDumpModuleRef(mod.ref) + end check_ir(job, mod) disableFallback = String[] @@ -8290,7 +8330,7 @@ function _link(job, (mod, adjoint_name, primal_name, TapeType)) # Now invoke the JIT jitted_mod = JIT.add!(mod) - adjoint_addr = JIT.lookup(jitted_mod, adjoint_name) + adjoint_addr = JIT.lookup(adjoint_name) adjoint_ptr = pointer(adjoint_addr) if adjoint_ptr === C_NULL @@ -8304,7 +8344,7 @@ function _link(job, (mod, adjoint_name, primal_name, TapeType)) if primal_name === nothing primal_ptr = C_NULL else - primal_addr = JIT.lookup(jitted_mod, primal_name) + primal_addr = JIT.lookup(primal_name) primal_ptr = pointer(primal_addr) if primal_ptr === C_NULL throw( diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 4b8f2d202a..7588eddb78 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -46,6 +46,34 @@ function absolute_symbol_materialization(name, ptr) return LLVM.absolute_symbols(Ref(gv)) end +const hnd_string_map = Dict{String, Ref{Ptr{Cvoid}}}() + +function fix_ptr_lookup(name) + if startswith(name, "ejlstr\$") || startswith(name, "ejlptr\$") + _, fname, arg1 = split(name, "\$") + if startswith(name, "ejlstr\$") + ptr = if haskey(hnd_string_map, arg1) + hnd_string_map[arg1] + else + val = Ref{Ptr{Cvoid}}(C_NULL) + hnd_string_map[arg1] = val + val + end + + return ccall( + :ijl_load_and_lookup, + Ptr{Cvoid}, + (Cstring, Cstring, Ptr{Cvoid}), + arg1, + fname, + ptr + ) + else + end + end + return nothing +end + function define_absolute_symbol(jd, name) ptr = LLVM.find_symbol(name) if ptr !== C_NULL @@ -213,6 +241,17 @@ function get_trampoline(job) end function add!(mod) + for f in collect(functions(mod)) + ptr = fix_ptr_lookup(LLVM.name(f)) + if ptr === nothing + continue + end + ptr = reinterpret(UInt, ptr) + ptr = LLVM.ConstantInt(ptr) + ptr = LLVM.const_inttoptr(ptr, LLVM.PointerType(LLVM.function_type(f))) + 
replace_uses!(f, ptr) + Compiler.eraseInst(mod, f) + end lljit = jit[].jit jd = LLVM.JITDylib(lljit) tsm = move_to_threadsafe(mod) @@ -220,7 +259,7 @@ function add!(mod) return nothing end -function lookup(_, name) +function lookup(name) LLVM.lookup(jit[].jit, name) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 5216244e52..3aebe5e528 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -432,7 +432,7 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) match_ = match(r"^jlplt_(.*)_\d+_got$", fname) if match_ !== nothing - fname = match_[1] + fname = String(match_[1]) FT = nothing todo = LLVM.Instruction[inst] while length(todo) != 0 @@ -453,7 +453,6 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) end end @assert FT !== nothing - newf, _ = get_function!(mod, String(fname), FT) initfn, _ = get_base_and_offset(LLVM.initializer(fn_got); offsetAllowed=false, inttoptr=false) loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst @@ -474,6 +473,7 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) opv = opv::LLVM.GlobalVariable if startswith(fname, "jl_") || startswith(fname, "ijl_") || startswith(fname, "_j_") + newf, _ = get_function!(mod, fname, FT) else found = nothing for lbb in blocks(initfn), linst in collect(instructions(lbb)) @@ -549,62 +549,26 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) throw(AssertionError(msg)) end - hnd = operands(found)[3] - - if !isa(hnd, LLVM.GlobalVariable) - msg = sprint() do io::IO - println( - io, - "Enzyme internal error unsupported got(hnd)", - ) - println(io, "inst=", inst) - println(io, "fname=", fname) - println(io, "FT=", FT) - println(io, "fn_got=", fn_got) - println(io, "init=", string(initfn)) - println(io, "opv=", string(opv)) - println(io, "found=", string(found)) - println(io, "hnd=", string(hnd)) - end - throw(AssertionError(msg)) - end - hnd = 
LLVM.name(hnd) - # println(string(mod)) - - # TODO we don't restore/lookup now because this fails - # @vchuravy / @gbaraldi this needs help looking at how to get the actual handle and setup - - if true - res = nothing - elseif arg1 isa AbstractString - res = ccall( - :ijl_load_and_lookup, - Ptr{Cvoid}, - (Cstring, Cstring, Ptr{Cvoid}), - arg1, - fname, - reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), - ) + fused_name = if arg1 isa AbstractString + "ejlstr\$$fname\$$arg1" else - res = ccall( - :ijl_load_and_lookup, - Ptr{Cvoid}, - (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), - arg1, - fname, - reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), - ) + if arg1 == reinterpret(Ptr{Nothing}, UInt(0x3)) + fname + else + arg1 = reinterpret(UInt, arg1) + "ejlptr\$$fname\$$arg1" + end end - if res !== nothing - push!(function_attributes(newf), StringAttribute("enzymejl_needs_restoration", string(convert(UInt, res)))) - end + newf, _ = get_function!(mod, fused_name, FT) + + push!(function_attributes(newf), StringAttribute("enzyme_math", fname)) # TODO we can make this relocatable if desired by having restore lookups re-create this got initializer/etc # metadata(newf)["enzymejl_flib"] = flib # metadata(newf)["enzymejl_flib"] = flib end - + if value_type(newf) != value_type(inst) newf = const_pointercast(newf, value_type(inst)) end @@ -761,7 +725,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, - reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + reinterpret(Ptr{Cvoid}, JIT.lookup(hnd).ptr), ) else res = ccall( @@ -770,7 +734,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, - reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + reinterpret(Ptr{Cvoid}, JIT.lookup(hnd).ptr), ) end replaceWith = LLVM.ConstantInt( From 83b908ede08354c5f5eb5b7b47b24a45d9cc1e70 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 3 Nov 2024 
23:13:29 -0600 Subject: [PATCH 412/495] Fix warning (#2053) --- src/compiler/validation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 3aebe5e528..f5383d7aee 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -480,7 +480,7 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) if !isa(linst, LLVM.CallInst) continue end - cv = LLVM.called_value(linst) + cv = LLVM.called_operand(linst) if !isa(cv, LLVM.Function) continue end From 899936dca468bc36c0e18cbc5220f4a74ece1fc5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 4 Nov 2024 08:43:01 -0600 Subject: [PATCH 413/495] Update Project.toml (#2055) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1fa4a49186..d5fddd0cfa 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4, 0.8.5" -Enzyme_jll = "0.0.160" +Enzyme_jll = "0.0.161" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 2e18dfc3ca4e4de674999e322269284e13604b6a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 4 Nov 2024 18:27:00 -0600 Subject: [PATCH 414/495] Continuing voyage of 1.111 (#2058) --- src/compiler.jl | 2 ++ test/runtests.jl | 26 +++++++++++++++++++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 00632463c8..786fba0134 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3257,6 +3257,7 @@ function annotate!(mod, mode) if haskey(funcs, fname) for fn in funcs[fname] push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + push!(parameter_attributes(fn, 1), LLVM.EnumAttribute("nocapture")) end end end @@ -3498,6 +3499,7 @@ function annotate!(mod, mode) if haskey(funcs, fname) for fn in funcs[fname] 
push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) + push!(return_attributes(fn), LLVM.EnumAttribute("nonnull", 0)) push!(function_attributes(fn), no_escaping_alloc) push!(function_attributes(fn), LLVM.EnumAttribute("mustprogress")) push!(function_attributes(fn), LLVM.EnumAttribute("willreturn")) diff --git a/test/runtests.jl b/test/runtests.jl index 31dcbd3e66..526be869ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2303,18 +2303,34 @@ end @testset "Broadcast noalias" begin x = ones(30) - autodiff(Reverse, bc0_test_function, Active, Const(x)) + + @static if VERSION < v"1.11-" + autodiff(Reverse, bc0_test_function, Active, Const(x)) + else + # TODO + @test_broken autodiff(Reverse, bc0_test_function, Active, Const(x)) + end x = rand(Float32, 2, 3) - Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) + @static if VERSION < v"1.11-" + Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) + else + # TODO + @test_broken Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) + end x = rand(Float32, 6, 6, 6, 2) sc = rand(Float32, 6) bi = rand(Float32, 6) - Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), - Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) + @static if VERSION < v"1.11-" + Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), + Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) + else + # TODO + @test_broken Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), + Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) + end end - function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T a1 = 1 / @inbounds poly[1] E1 = 2*a1 From df20f1883ece2ca9b346568ea3f82f2f6247a283 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 4 Nov 2024 18:27:43 -0600 Subject: [PATCH 415/495] Debug specfunc (#2052) * Debug 
specfunc * Add test * Update optimize.jl --- src/compiler.jl | 2 +- test/optimize.jl | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 786fba0134..e35fd8b6a1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3453,7 +3453,7 @@ function annotate!(mod, mode) LLVM.API.LLVMRemoveEnumAttributeAtIndex( fn, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), - kind(EnumAttribute("allockind")), + kind(EnumAttribute("allockind", AllocFnKind(AFKE_Alloc).data)), ) push!(function_attributes(fn), no_escaping_alloc) push!(function_attributes(fn), LLVM.EnumAttribute("allockind", (AllocFnKind(AFKE_Alloc) | AllocFnKind(AFKE_Uninitialized)).data)) diff --git a/test/optimize.jl b/test/optimize.jl index 7bc89ecc74..4792ac0570 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -115,3 +115,18 @@ end Enzyme.Const(dist) ) end + +@noinline function mc_g(i, _not_used) + k = (0.25) + return (i, k) +end + +function mc_f(_not_used) + i = (0.0, 3.9555) + t = mc_g(i, _not_used) + return t[1][2] +end + +@testset "Memcopy of constant" begin + @test Enzyme.autodiff(Enzyme.Forward, mc_f, Duplicated(2.7, 1.0))[1] ≈ 0.0 +end From 3548f3e4d977341511e49383334ed5dfc5053888 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 5 Nov 2024 09:21:19 -0600 Subject: [PATCH 416/495] Update Project.toml (#2062) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d5fddd0cfa..a0f5a7190e 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4, 0.8.5" -Enzyme_jll = "0.0.161" +Enzyme_jll = "0.0.162" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 527f8966d497c96662b53c046f3a6b7c8325f92a Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 5 Nov 2024 13:24:52 -0600 Subject: [PATCH 
417/495] 111s (#2059) * Continuing 1.11 support * fix * fix * Continuing 1.11 support * fix * fix * of course 1.11 broke random seeds * fix --- src/compiler/optimize.jl | 60 +++++++++++++++++++++++--------------- src/compiler/validation.jl | 4 ++- test/runtests.jl | 38 +++++++++++------------- 3 files changed, 57 insertions(+), 45 deletions(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 2277f27250..94bfd0bfef 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -867,7 +867,7 @@ function nodecayed_phis!(mod::LLVM.Module) end end - if addr == 11 && isa(v, LLVM.ConstantExpr) + if isa(v, LLVM.ConstantExpr) if opcode(v) == LLVM.API.LLVMAddrSpaceCast v2 = operands(v)[1] if addrspace(value_type(v2)) == 10 @@ -890,6 +890,42 @@ function nodecayed_phis!(mod::LLVM.Module) return v2, offset, hasload end end + if opcode(v) == LLVM.API.LLVMBitCast + preop = operands(v)[1] + while isa(preop, LLVM.ConstantExpr) && opcode(preop) == LLVM.API.LLVMBitCast + preop = operands(preop)[1] + end + v2, offset, skipload = + getparent(preop, offset, hasload) + v2 = const_bitcast( + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + + if opcode(v) == LLVM.API.LLVMGetElementPtr + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + offset = const_add( + offset, + API.EnzymeComputeByteOffsetOfGEP(b, v, offty), + ) + v2 = const_bitcast( + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + end if isa(v, LLVM.AddrSpaceCastInst) @@ -973,28 +1009,6 @@ function nodecayed_phis!(mod::LLVM.Module) return v2, offset, skipload end - if isa(v, LLVM.ConstantExpr) && - opcode(v) == LLVM.API.LLVMGetElementPtr && - !hasload - v2, offset, skipload = - getparent(operands(v)[1], offset, 
hasload) - offset = nuwadd!( - b, - offset, - API.EnzymeComputeByteOffsetOfGEP(b, v, offty), - ) - v2 = bitcast!( - b, - v2, - LLVM.PointerType( - eltype(value_type(v)), - addrspace(value_type(v2)), - ), - ) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end - undeforpoison = isa(v, LLVM.UndefValue) @static if LLVM.version() >= v"12" undeforpoison |= isa(v, LLVM.PoisonValue) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index f5383d7aee..0ce2e61eb1 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -509,7 +509,9 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) if legal1 else arg1, _ = get_base_and_offset(operands(found)[1]; offsetAllowed=false, inttoptr=true) - if !isa(arg1, LLVM.ConstantInt) + if isa(arg1, LLVM.PointerNull) + arg1 = LLVM.ConstantInt(0) + elseif !isa(arg1, LLVM.ConstantInt) msg = sprint() do io::IO println( io, diff --git a/test/runtests.jl b/test/runtests.jl index 526be869ee..d00ab08ca3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2312,25 +2312,15 @@ end end x = rand(Float32, 2, 3) - @static if VERSION < v"1.11-" - Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) - else - # TODO - @test_broken Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) - end + Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) x = rand(Float32, 6, 6, 6, 2) sc = rand(Float32, 6) bi = rand(Float32, 6) - @static if VERSION < v"1.11-" - Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), - Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) - else - # TODO - @test_broken Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), - Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) - end + Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), + 
Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) end + function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T a1 = 1 / @inbounds poly[1] E1 = 2*a1 @@ -2719,9 +2709,18 @@ end x = [2.3] dx = [0.0] + rf = @static if VERSION < v"1.11-" + nothing + else + dx.ref.mem + end @test 1.0 ≈ first(Enzyme.autodiff(Reverse, pusher, Duplicated(x, dx), Active(2.0)))[2] - @test x ≈ [2.3, 2.0] - @test dx ≈ [1.0] + @static if VERSION < v"1.11-" + @test dx ≈ [1.0] + else + @test dx ≈ [0.0, 0.0] + @test rf ≈ [1.0,] + end function double_push(x) a = [0.5] @@ -3429,14 +3428,11 @@ end fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(cual)}, Duplicated) end - -const SEED = 42 const N_SAMPLES = 500 const N_COMPONENTS = 4 -const rnd = Random.MersenneTwister(SEED) -const data = randn(rnd, N_SAMPLES) -const params0 = [rand(rnd, N_COMPONENTS); randn(rnd, N_COMPONENTS); 2rand(rnd, N_COMPONENTS)] +const data = [-0.5560268761463861, -0.444383357109696, 0.027155338009193845, -0.29948409035891055, 1.7778610980573246, -1.14490153172882, -0.46860588216767457, 0.15614346264074028, -2.641991008076796, 1.0033099014594844, 1.0823812056084292, 0.18702790710363, 0.5181487878771377, 1.4913791170403063, 0.3675627461748204, -0.8862052960481365, 0.6845647041648603, -1.590579974922555, 0.410653382404333, -0.856349830552304, -1.0509877103612828, 0.502079457623539, -0.2162480073298127, -0.7064242722014349, -3.663034802576991, 0.1683412659309176, 0.28425906701710857, 0.5698286489101805, -1.4220589095735332, -0.37240087577993225, 0.36901028455183293, -0.007612980079313577, 0.562668812321259, 0.10686911035365092, 0.5694584949295476, 0.6810849274435286, -1.3391251213773154, -0.23828371819888622, 1.0193587887377156, 0.7017713284136731, -0.14521050140729616, 0.6428964706243068, 1.8193484973888463, -0.3672598918343543, 0.7565689715912548, 0.08701681344541234, -0.8511936279150268, 0.9242378649817816, 1.6120908802143608, -0.9258028623888832, 0.49199249434769243, 
-0.22608145131612992, -1.351640432408328, 0.3023655023653634, 1.2252567101879008, -0.4579776898846808, -0.36873503034471294, -0.5879094743345893, 0.8901285443512134, 0.23942258932052793, -0.195126767304092, 0.6541265910968944, -0.8180579409672205, -1.6505670679051851, -0.41299157333934094, 0.027291621540711814, 0.29759513748684024, 0.07136737211887596, -0.8945184707955289, -1.9947538311302822, -0.18728482002197755, -0.5854431752183636, -1.5094025801991475, -0.10841979845993832, -0.37149972405672654, 1.209427254258146, -0.20401001223483511, 0.012484378414156184, 0.14058032536255227, -0.9923922290801328, -1.1484589871168687, 1.3715759475375402, -0.05906784259018913, -0.3530786655768097, -1.4488810057984256, 2.3879153026952267, 0.12580788649805388, -1.9725913559635857, 0.7009118499232365, 0.31700578675698954, 0.2762925198922067, -1.625619456664758, -0.9373153541158463, -0.9304928802188113, -1.9905539723314605, 0.13753980192836776, 3.1495759241275225, -0.7214121984874265, -0.577092696986848, 0.4593837753047561, 0.24720770605505335, -0.02566249380406308, -0.6320809680268722, -1.0204193548690512, -1.311507800204285, 0.687066447881175, -0.09460071173365793, -0.28474772376166907, -0.3387627167144301, -0.09536748694092163, 0.7689715736798227, 1.442443602597533, -0.27503213774867397, -0.37963749230903393, -0.7226963203200736, 0.13966112558116067, -1.4093392453795512, 1.0554566357077169, -2.237822231313888, 1.15915397311866, -0.6901150042613587, -1.3821310931236412, 0.5938504216651738, 0.6609960603802911, 2.7589164663644565, 0.46763556069956336, -0.08060548262614446, 0.0795712995405885, -0.36251216727772734, -0.6883308052782828, -0.6957669363357581, -0.4298941588197229, -0.6170193131914518, -0.7875381233595508, 0.9793681640487545, -0.16689001105724968, -0.4410353697187671, -0.0106585238662763, 0.4075406194010163, -0.3824860969744455, -0.4306357340413131, -0.05292489819141157, 0.7631065933222378, -1.7078224461664113, -0.665051961459703, 1.5950208497188643, 0.5424677877448506, 
0.2702908010561237, 1.1637402735830906, -0.9752391996154888, 0.591290372307129, -0.5811624500841813, -1.0412662750801527, -0.19245292741043743, -0.4348102015339658, 0.08422740657017527, 1.0438125328282608, -0.4927174182298662, 1.2458570216388754, 1.1205311509593492, -0.12330863869436813, -1.0664768973086367, 0.30470144397407023, -1.8010263359728695, -0.13268665889419914, 0.630295337858245, -2.3417617931253183, -0.15973244410747056, 0.6795317720001368, -0.7447645337583819, 1.2306970723625588, 1.090597639554929, 1.8958640522777934, -0.26786676662796843, 2.0394271629613625, 0.055740331653061796, -0.7193540879657702, -0.1628910819232136, -0.2882790142695281, -2.0534491466053764, -2.233319106981269, -1.1534087585837147, -1.5591843759684125, -1.3523434858414614, -0.35519147266027956, -0.9383662929403082, 0.5502010944298217, -1.6530342319857132, -1.0177969113873517, 1.3546070391530678, -0.7303143540080486, 1.4594015115819061, 1.1755531107578732, 1.469591632664121, 2.155036256421912, -0.1978160724116997, -1.238444066448837, 0.4762431627842421, 0.5664035042754365, 2.191213907869695, -0.16697076769423774, -1.8401747205811765, -0.5935878583224328, -0.4447185333005777, 1.8811333529161927, -0.857003256515023, 0.5308971512660459, 0.9475262937475932, 0.5065543926241618, -1.426876319667508, 0.27277024759538804, 1.6832914767785863, 0.8794490994419152, -0.37229135630818994, 1.2103793835305425, -0.8145152351424103, -0.6637107250031519, -0.3642002730267983, 0.128180863566292, -0.8555397517942166, 0.7463496214304585, -0.21349745615233298, 0.6069963236292358, -0.15043631758004797, 0.2865438734856877, 0.9689530290178722, 0.4645944393389331, -0.10075844220363214, 0.9719135191686711, 1.359272322470581, 0.9198388707805807, -0.003947750510092571, 0.6651494514356097, -0.4642945862879513, -0.09632384020701204, -1.4640316974713914, 0.03411735606151582, 0.192612577289544, 1.2723529571406376, -0.6797254154698416, 0.7121446020587434, 1.6474227377969937, -1.3612960160873442, 1.639942921300844, 
-1.2934385805566249, -0.6093348312692207, -0.4929035640975793, 0.07652791635611562, 0.15922627820027674, 0.4446393063910561, -1.212412143247565, -1.3517775856153358, -1.0318877340986508, 1.074228515227531, 1.0673524298941364, -0.17007000912445897, 0.19378013515263207, -2.4816300047227666, -0.4592570057697022, -1.1897921624102403, 0.26056620827295174, -0.6309468513820905, 1.5399524139961005, -0.10352720131589113, 1.0498414218815497, -0.08560706218145639, -1.968271606952006, 0.9137220126511895, 1.5165903941194543, -0.9634560368533389, 1.1884250536535346, -0.23295101440683805, 0.9553533369282987, -0.3098984978002516, -0.042208017646174, -1.9930838186185373, 0.6230463669791857, -2.7605050117797982, 1.2120139690044167, 1.5742425795634203, -0.8464907448939961, 0.7321425928205605, 1.044045122552859, 1.6213749963355164, 2.750115535545452, 1.7347194078359887, -1.2300524375750894, -0.36190025258293085, 0.16420959796469084, -0.2846711718337991, 1.815652557702069, -0.7696456318219987, -0.3758835433623372, -0.31538987759342985, 0.23203583241300865, -0.9042757791617796, 0.14623184817251003, 0.22324769054960142, -0.07430046379776922, 0.8598376944077396, 0.8094753739829023, 0.7780695563934816, -0.9880838058116279, -0.17529075003709038, -1.4320848711398677, 0.49819547701694217, 1.455253953626022, -0.8646238569013837, 0.6274090322589988, 0.7214994440898491, 1.0249395310099292, 1.6051684957766426, -0.41752946512010924, -1.187706044484646, 1.9667607247158339, -0.8273416405870931, -1.8095812957335915, 0.21946085689792014, 1.6959323077028474, 0.07606410600663914, -0.0005899969679317951, -0.6300575938012335, 0.7168660929110295, -1.6957830800502658, -2.378949781992021, 0.1614508487249226, -0.34807928510100733, -0.19506959830062723, -0.5497606187369344, 1.0808323747949233, -2.2125060475463463, 0.8718983568753548, -0.007206357093314457, -1.575891948273875, -2.2088301903139564, -0.6163495955240155, -0.5801739513350528, -1.5612897485472592, -1.3002596896606895, -1.0059040824152614, 
0.6796485760534101, -0.043207370167515954, -0.039839626218822005, -0.4385362125324239, -0.09173118968880091, 1.22561207137653, -0.232131676978592, -1.2139396904067505, -0.23690460123730323, -0.3827075659919168, -1.9688978438045297, -1.5797479817238906, -0.5654974841550139, -1.7170129974387656, -2.446261897220929, -0.26713540122649804, 0.6778692338783507, -0.1689008828898645, 1.604831121738095, 1.7480788262457672, 0.9166815687612451, 1.1341703209400371, -2.1775754411288144, -0.330419660506073, -0.2672312785080624, 1.200964147356692, 1.3170491877286854, 0.6023924017880021, -0.9827718177516547, -1.1457095184571038, 0.25819205428715203, -2.282547976724439, -3.049187087985662, -1.281790532097414, 2.8015194483397003, 0.5639614209301308, -0.45753014518031915, -0.7991285413217107, 1.0753460926351635, -0.5569593865129592, 0.049550548181828, -0.8913053693383933, -0.7053146611866707, 1.3970437844025363, -1.9127587026035067, -0.6408264977866664, -0.4027208663305603, 0.2116527896752755, 2.2400517749401025, -0.881636745059659, -0.6176341167004805, -0.3916247912848145, 0.9513607440394651, 0.14123219972588208, 1.2053043404475006, 0.023930401450278135, -0.8132651104773965, 1.010114660634259, 0.14932562573657962, 1.7774482649689414, 0.8427840155284196, -0.9801124442248111, 1.2865644225495012, 0.4389849473895256, -0.505701456587577, -1.6549980227323258, 0.6515278406227536, -0.5295755930612868, 0.9056306662830947, 0.08972278324911305, 0.23264270532152218, 2.627396969584777, 0.27204169314700904, 0.9247523784433901, 0.39543449166392586, -3.0074905237355902, 1.1183821383894414, 0.17479140498380819, -1.2141175099769368, 0.19312543732457393, 0.3046417196295455, -2.1349686721255985, 0.5660453142567702, 0.025065143849368067, -0.3915696698906197, 0.8816658350282802, 0.8266483514919597, 0.9493314906580934, -0.0032065446136149488, -1.7961894556919296, 0.4130469384119612, -1.28722133892104, -1.119436381330247, 0.773218214182339, -0.3837586462966479, 0.07777043635090204, -0.7542646791925891, 
0.08136609240300065, 0.2746677233237281, 1.1122181237929718, 0.5326574958161293, 0.7823989790674563, -0.31307892473155574, -0.04580883976550116, 0.1691926964033846, 0.37103104440006834, -0.9432191269564248, -0.7609096208689287, 0.2804422856751161, -0.048373157442897496, -0.981155307666483, 1.3029831269962606, 0.059610387285286434, 0.12450090856951314, 0.11358777574045617, -1.3306646401495767, -0.34798310991558473, -0.2866743445913757, 0.674272748688434, -0.4239489256735372, -0.9229714850145041, 0.3113603292165489, -0.4846778890580723, -0.013595195649033436, 1.2403767852654752, 1.0331480262937687, -0.11947562831616687, -0.6057052894354995, -0.7754699190529916, 1.1616052799742849, 1.2438648692130239, 0.027783265850463142, -1.2121179280013439, 0.6370251861203581, 0.6834320658257506, 0.6241655870590277, -1.353228100410462, 0.8938693570362417, 0.8374026807814964, -0.3732266794347597, 1.1790529100520817, 0.7911863169212741, 0.2189687616385222, -0.6113204774701082, 0.19900691423850023, 0.31468457309049136, 1.2458549461519632, 0.5053751217075103, -0.4497826390316608, 0.6003636947378631, 0.7879125998061005, 0.4768361874753698, -0.7096215982620703, 0.09448322860785968, -1.6374823189754906, 1.1567167561713774, 0.7983310036650442, -1.3254511689721826, -0.2200472053270165, 0.629399242764823] +const params0 = [0.25733304995705586, 0.4635056170085754, 0.5285451129509773, 0.7120981447127772, 0.835601264145011, -1.4646785862195637, 0.24736086263101278, -0.21967358320549735, 1.0624643704713206, 1.628664511492019, 1.8530572439128092, 0.6276756477143253] # ========== Objective function ========== normal_pdf(x::Real, mean::Real, var::Real) = From a0552afae1ad39b1b21140383c949368f646225a Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 5 Nov 2024 13:45:39 -0600 Subject: [PATCH 418/495] Add absint fix and tests (#2060) * Add absint fix and tests * fix * fix * fix * fix --- src/absint.jl | 152 +++++++++++++++++++++-------------------- src/compiler.jl | 81 ++++++++++++---------- 
src/rules/typerules.jl | 20 ++++-- src/typetree.jl | 33 +++------ src/utils.jl | 61 ++++++++++++++++- test/absint.jl | 17 +++++ test/runtests.jl | 1 + test/typetree.jl | 1 + 8 files changed, 224 insertions(+), 142 deletions(-) create mode 100644 test/absint.jl diff --git a/src/absint.jl b/src/absint.jl index e006596b71..b706d4134a 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -168,7 +168,7 @@ end function actual_size(@nospecialize(typ2)) @static if VERSION < v"1.11-" if typ2 <: Array - return sizeof(Int) + return sizeof(Ptr{Cvoid}) + 2 + 2 + 4 + 2 * sizeof(Csize_t) + sizeof(Csize_t) end else if typ2 <: GenericMemory @@ -185,6 +185,11 @@ function actual_size(@nospecialize(typ2)) end @inline function first_non_ghost(@nospecialize(typ2)) + @static if VERSION < v"1.11-" + if typ2 <: Array + return (1, typed_fieldtype(typ2, 1)) + end + end fc = fieldcount(typ2) for i in 1:fc if i == fc @@ -496,101 +501,100 @@ function abs_typeof( larg, offset = get_base_and_offset(operands(arg)[1]) legal, typ, byref = abs_typeof(larg, false, seenphis) - if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) && Base.isconcretetype(typ) - @static if VERSION < v"1.11-" - if typ <: Array && Base.isconcretetype(typ) - T = eltype(typ) - if offset == 0 - if !allocatedinline(T) && Base.isconcretetype(T) - T = Ptr{T} - end - return (true, Ptr{T}, GPUCompiler.BITS_VALUE) - else - return (true, Int, GPUCompiler.BITS_VALUE) - end - end + dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) + + shouldLoad = true + + if legal && typ <: Ptr && Base.isconcretetype(typ) && byref == GPUCompiler.BITS_VALUE + ET = eltype(typ) + sz = actual_size(ET) + offset %= sz + byref = GPUCompiler.MUT_REF + typ = ET + if !Base.allocatedinline(typ) + shouldLoad = false end - if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF - dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - + end + + if legal && (byref == GPUCompiler.MUT_REF || byref == 
GPUCompiler.BITS_REF) && Base.isconcretetype(typ) + if shouldLoad byref = GPUCompiler.BITS_VALUE - legal = true + end - while offset != 0 && legal - @assert Base.isconcretetype(typ) - seen = false - lasti = 1 - for i = 1:fieldcount(typ) - fo = fieldoffset(typ, i) - if fieldoffset(typ, i) == offset - offset = 0 - typ = typed_fieldtype(typ, i) - if !Base.allocatedinline(typ) - if byref != GPUCompiler.BITS_VALUE - legal = false - end - byref = GPUCompiler.MUT_REF - end - seen = true - break - elseif fieldoffset(typ, i) > offset - offset = offset - fieldoffset(typ, lasti) - typ = typed_fieldtype(typ, lasti) - @assert Base.isconcretetype(typ) - if !Base.allocatedinline(typ) + legal = true + + while offset != 0 && legal + @assert Base.isconcretetype(typ) + seen = false + lasti = 1 + for i = 1:typed_fieldcount(typ) + fo = typed_fieldoffset(typ, i) + if fo == offset + offset = 0 + typ = typed_fieldtype(typ, i) + if !Base.allocatedinline(typ) + if byref != GPUCompiler.BITS_VALUE legal = false end - seen = true - break - end - - if fo != 0 && fo != fieldoffset(typ, i-1) - lasti = i + byref = GPUCompiler.MUT_REF end - end - if !seen && fieldcount(typ) > 0 - offset = offset - fieldoffset(typ, lasti) - typ = typed_fieldtype(typ, lasti) + seen = true + break + elseif fo > offset + offset = offset - typed_fieldoffset(typ, lasti) + typ = typed_fieldtype(typ, lasti) @assert Base.isconcretetype(typ) if !Base.allocatedinline(typ) legal = false end seen = true + break + end + + if fo != 0 && fo != typed_fieldoffset(typ, i-1) + lasti = i end - if !seen + end + if !seen && typed_fieldcount(typ) > 0 + offset = offset - typed_fieldoffset(typ, lasti) + typ = typed_fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) + if !Base.allocatedinline(typ) legal = false end + seen = true end + if !seen + legal = false + end + end - typ2 = typ - while legal && should_recurse(typ2, value_type(arg), byref, dl) - idx, _ = first_non_ghost(typ2) - if idx != -1 - typ2 = typed_fieldtype(typ2, idx) - if 
Base.allocatedinline(typ2) - if byref == GPUCompiler.BITS_VALUE - continue - end + typ2 = typ + while legal && should_recurse(typ2, value_type(arg), byref, dl) + idx, _ = first_non_ghost(typ2) + if idx != -1 + typ2 = typed_fieldtype(typ2, idx) + if Base.allocatedinline(typ2) + if byref == GPUCompiler.BITS_VALUE + continue + end + legal = false + break + else + if byref != GPUCompiler.BITS_VALUE legal = false break - else - if byref != GPUCompiler.BITS_VALUE - legal = false - break - end - byref = GPUCompiler.MUT_REF - continue end + byref = GPUCompiler.MUT_REF + continue end - legal = false - break - end - if legal - return (true, typ2, byref) end + legal = false + break + end + if legal + return (true, typ2, byref) end - elseif legal && typ <: Ptr && Base.isconcretetype(typ) - return (true, eltype(typ), GPUCompiler.BITS_VALUE) end end diff --git a/src/compiler.jl b/src/compiler.jl index e35fd8b6a1..64e6703cd6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4008,15 +4008,19 @@ function enzyme!( logic = Logic() TA = TypeAnalysis(logic, rules) - retT = - (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? - Ptr{actualRetType} : actualRetType - retTT = - ( - !isa(actualRetType, Union) && + retTT = if !isa(actualRetType, Union) && actualRetType <: Tuple && in(Any, actualRetType.parameters) - ) ? 
TypeTree() : typetree(retT, ctx, dl, seen) + TypeTree() + else + typeTree = typetree(actualRetType, ctx, dl, seen) + if !isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType) + typeTree = copy(typeTree) + merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) + only!(typeTree, -1) + end + typeTree + end typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -5569,8 +5573,11 @@ function lower_convention( if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = - to_md(typetree(Ptr{actualRetType}, ctx, dl, seen), ctx) + + typeTree = copy(typetree(actualRetType, ctx, dl, seen)) + merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) + only!(typeTree, -1) + metadata(sretPtr)["enzyme_type"] = to_md(typeTree, ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) @@ -5606,8 +5613,11 @@ function lower_convention( metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = - to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), ctx) + + typeTree = copy(typetree(arg.typ, ctx, dl, seen)) + merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) + only!(typeTree, -1) + metadata(ptr)["enzyme_type"] = to_md(typeTree, ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end @@ -5639,11 +5649,14 @@ function lower_convention( wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm) ctx = LLVM.context(wrapparm) push!(wrapper_args, wrapparm) + typeTree = copy(typetree(arg.typ, ctx, dl, seen)) + merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) + only!(typeTree, -1) push!( parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), StringAttribute( "enzyme_type", - string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen)), + string(typeTree), ), ) push!( @@ -6336,7 +6349,7 @@ function GPUCompiler.codegen( byref = arg.cc - rest = typetree(arg.typ, ctx, dl) + rest = copy(typetree(arg.typ, ctx, dl)) if 
byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader @@ -6382,7 +6395,14 @@ function GPUCompiler.codegen( if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() @assert !retRemoved - rest = typetree(llRT, ctx, dl) + rest = if llRT == Ptr{RT} + typeTree = copy(typetree(RT, ctx, dl)) + merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) + only!(typeTree, -1) + typeTree + else + typetree(RT, ctx, dl) + end push!(return_attributes(f), StringAttribute("enzyme_type", string(rest))) end end @@ -6972,33 +6992,18 @@ end legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal - typ = if codegen_typ isa LLVM.PointerType || codegen_typ isa LLVM.IntegerType - llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) - # pointers are used for multiple kinds of arguments - # - literal pointer values - if byref == GPUCompiler.BITS_VALUE - source_typ - elseif byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF - Ptr{source_typ} - else - msg = sprint() do io - println(io, "Enzyme illegal state") - println(io, string(f)) - println(io, "legal=", legal) - println(io, "source_typ=", source_typ) - println(io, "byref=", byref) - println(io, "llvm_source_typ=", llvm_source_typ) - println(io, "codegen_typ=", codegen_typ) - println(io, "inst=", string(inst)) - println(io, enzyme_custom_extract_mi(f)) - end - throw(AssertionError(msg)) - end + if codegen_typ isa LLVM.PointerType || codegen_typ isa LLVM.IntegerType else + @assert byref == GPUCompiler.BITS_VALUE source_typ end - ec = typetree(typ, ctx, dl, seen) + ec = typetree(source_typ, ctx, dl, seen) + if byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF + ec = copy(ec) + merge!(ec, TypeTree(API.DT_Pointer, ctx)) + only!(ec, -1) + end if isa(inst, LLVM.CallInst) LLVM.API.LLVMAddCallSiteAttribute( inst, @@ -7010,6 +7015,8 @@ end ) else 
metadata(inst)["enzyme_type"] = to_md(ec, ctx) + metadata(inst)["enzymejl_source_type_$(source_typ)"] = MDNode(LLVM.Metadata[]) + metadata(inst)["enzymejl_byref_$(byref)"] = MDNode(LLVM.Metadata[]) @static if VERSION < v"1.11-" else diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index de11d3c1cd..67c4776c4a 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -32,10 +32,12 @@ function inout_rule( if (direction & API.DOWN) != 0 ctx = LLVM.context(inst) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) + rest = typetree(typ, ctx, dl) if GPUCompiler.deserves_retbox(typ) - typ = Ptr{typ} + rest = copy(rest) + merge!(rest, TypeTree(API.DT_Pointer, ctx)) + only!(rest, -1) end - rest = typetree(typ, ctx, dl) changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) @assert legal end @@ -72,10 +74,12 @@ function inoutcopyslice_rule( if (direction & API.DOWN) != 0 ctx = LLVM.context(inst) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) + rest = typetree(typ, ctx, dl) if GPUCompiler.deserves_retbox(typ) - typ = Ptr{typ} + rest = copy(rest) + merge!(rest, TypeTree(API.DT_Pointer, ctx)) + only!(rest, -1) end - rest = typetree(typ, ctx, dl) changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) @assert legal end @@ -112,10 +116,12 @@ function inoutgcloaded_rule( if (direction & API.DOWN) != 0 ctx = LLVM.context(inst) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) + rest = typetree(typ, ctx, dl) if GPUCompiler.deserves_retbox(typ) - typ = Ptr{typ} + rest = copy(rest) + merge!(rest, TypeTree(API.DT_Pointer, ctx)) + only!(rest, -1) end - rest = typetree(typ, ctx, dl) changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) @assert legal end @@ -131,4 +137,4 @@ function inoutgcloaded_rule( @assert legal end return UInt8(false) -end \ No newline at end of file +end diff --git a/src/typetree.jl b/src/typetree.jl index adf3738625..aa7d4b08dd 100644 --- a/src/typetree.jl +++ 
b/src/typetree.jl @@ -255,6 +255,10 @@ function typetree_inner( seen::TypeTreeTable, ) where {T} tt = copy(typetree(T == UInt8 ? Nothing : T, ctx, dl, seen)) + if !allocatedinline(T) && Base.isconcretetype(T) + merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, 0) + end merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) return tt @@ -264,13 +268,8 @@ end function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} offset = 0 - tt = copy(typetree(T, ctx, dl, seen)) - if !allocatedinline(T) && Base.isconcretetype(T) - merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, 0) - end - merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, offset) + tt = copy(typetree(Ptr{T}, ctx, dl, seen)) + shift!(tt, dl, 0, sizeof(Int), offset) offset += sizeof(Ptr{Cvoid}) @@ -291,14 +290,8 @@ else dl, seen::TypeTreeTable, ) where {kind,T} - offset = 0 - tt = copy(typetree(T, ctx, dl, seen)) - if !allocatedinline(T) && Base.isconcretetype(T) - merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, 0) - end - merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, sizeof(Csize_t)) + tt = copy(typetree(Ptr{T}, ctx, dl, seen)) + shift!(tt, dl, 0, sizeof(Int), sizeof(Csize_t)) for i = 0:(sizeof(Csize_t)-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) @@ -312,14 +305,8 @@ else dl, seen::TypeTreeTable, ) where {kind,T} - offset = 0 - tt = copy(typetree(T, ctx, dl, seen)) - if !allocatedinline(T) && Base.isconcretetype(T) - Enzyme.merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, 0) - end - Enzyme.merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, 0) + tt = copy(typetree(Ptr{T}, ctx, dl, seen)) + shift!(tt, dl, 0, sizeof(Int), 0) for f = 2:fieldcount(AT) offset = fieldoffset(AT, f) diff --git a/src/utils.jl b/src/utils.jl index 9492dccdd6..1cb3504725 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -336,8 +336,57 @@ export my_methodinstance @static if VERSION < v"1.11-" +# JL_EXTENSION typedef struct { +# JL_DATA_TYPE +# void *data; +# #ifdef STORE_ARRAY_LEN 
(just true new newer versions) +# size_t length; +# #endif +# jl_array_flags_t flags; +# uint16_t elsize; // element size including alignment (dim 1 memory stride) +# uint32_t offset; // for 1-d only. does not need to get big. +# size_t nrows; +# union { +# // 1d +# size_t maxsize; +# // Nd +# size_t ncols; +# }; +# // other dim sizes go here for ndims > 2 +# +# // followed by alignment padding and inline data, or owner pointer +# } jl_array_t; @inline function typed_fieldtype(@nospecialize(T::Type), i::Int) - fieldtype(T, i) + if T <: Array + eT = eltype(T) + PT = Ptr{eT} + return (PT, Csize_t, UInt16, UInt16, UInt32, Csize_t, Csize_t)[i] + else + fieldtype(T, i) + end +end + +@inline function typed_fieldcount(@nospecialize(T::Type)) + if T <: Array + return 7 + else + fieldcount(T) + end +end + +@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int) + if T <: Array + tys = (Ptr, Csize_t, UInt16, UInt16, UInt32, Csize_t, Csize_t) + sum = 0 + idx = 1 + while idx < i + sum += sizeof(tys[idx]) + idx+=1 + end + return sum + else + fieldoffset(T, i) + end end else @@ -355,9 +404,19 @@ else end end +@inline function typed_fieldcount(@nospecialize(T::Type)) + fieldcount(T) +end + +@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int) + fieldoffset(T, i) +end + end export typed_fieldtype +export typed_fieldcount +export typed_fieldoffset # returns the inner type of an sret/enzyme_sret/enzyme_sret_v function sret_ty(fn::LLVM.Function, idx::Int) diff --git a/test/absint.jl b/test/absint.jl new file mode 100644 index 0000000000..ce17ca7f19 --- /dev/null +++ b/test/absint.jl @@ -0,0 +1,17 @@ +using Enzyme, Test + +struct BufferedMap!{X} + x_buffer::Vector{X} +end + +function (bc::BufferedMap!)() + return @inbounds bc.x_buffer[1][1] +end + + +@testset "Internal tests" begin + f = BufferedMap!([[2.7]]) + df = BufferedMap!([[3.1]]) + + @test autodiff(Forward, Duplicated(f, df))[1] ≈ 3.1 +end diff --git a/test/runtests.jl b/test/runtests.jl index 
d00ab08ca3..3b338eff18 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,6 +82,7 @@ include("kwrrules.jl") include("internal_rules.jl") include("ruleinvalidation.jl") include("typeunstable.jl") +include("absint.jl") @static if !Sys.iswindows() include("blas.jl") diff --git a/test/typetree.jl b/test/typetree.jl index 074103ea7c..7ea243a4d1 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -76,6 +76,7 @@ end "{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,0]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,0]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" @static if VERSION < v"1.11-" + @test tt(Vector{Vector{Float32}}) == "{[0]:Pointer, [0,0]:Pointer, [0,0,0]:Pointer, [0,0,0,-1]:Float@float, [0,0,8]:Integer, [0,0,9]:Integer, [0,0,10]:Integer, [0,0,11]:Integer, [0,0,12]:Integer, [0,0,13]:Integer, [0,0,14]:Integer, [0,0,15]:Integer, [0,0,16]:Integer, [0,0,17]:Integer, [0,0,18]:Integer, [0,0,19]:Integer, [0,0,20]:Integer, [0,0,21]:Integer, [0,0,22]:Integer, [0,0,23]:Integer, [0,0,24]:Integer, [0,0,25]:Integer, [0,0,26]:Integer, [0,0,27]:Integer, [0,0,28]:Integer, [0,0,29]:Integer, [0,0,30]:Integer, [0,0,31]:Integer, [0,0,32]:Integer, [0,0,33]:Integer, [0,0,34]:Integer, [0,0,35]:Integer, [0,0,36]:Integer, [0,0,37]:Integer, [0,0,38]:Integer, [0,0,39]:Integer, [8]:Integer, [9]:Integer, [10]:Integer, [11]:Integer, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Integer, [25]:Integer, [26]:Integer, [27]:Integer, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer, [32]:Integer, [33]:Integer, [34]:Integer, [35]:Integer, [36]:Integer, [37]:Integer, [38]:Integer, [39]:Integer}" else @test tt(MemoryRef{Float32}) == "{[-1]:Pointer, [0,-1]:Float@float, 
[8,0]:Integer, [8,1]:Integer, [8,2]:Integer, [8,3]:Integer, [8,4]:Integer, [8,5]:Integer, [8,6]:Integer, [8,7]:Integer, [8,8]:Pointer, [8,8,-1]:Float@float}" end From 19e3ccb9c58eb12e5bc68478b1b22645d4bb8366 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 5 Nov 2024 14:30:07 -0600 Subject: [PATCH 419/495] Fix offset size for absint load (#2064) * Fix offset size for absint load * With test --- src/absint.jl | 6 ++++-- test/absint.jl | 8 +++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index b706d4134a..2ae9957103 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -507,12 +507,14 @@ function abs_typeof( if legal && typ <: Ptr && Base.isconcretetype(typ) && byref == GPUCompiler.BITS_VALUE ET = eltype(typ) - sz = actual_size(ET) - offset %= sz byref = GPUCompiler.MUT_REF typ = ET if !Base.allocatedinline(typ) shouldLoad = false + offset %= sizeof(Int) + else + sz = actual_size(ET) + offset %= sz end end diff --git a/test/absint.jl b/test/absint.jl index ce17ca7f19..ca2ad8f502 100644 --- a/test/absint.jl +++ b/test/absint.jl @@ -9,9 +9,15 @@ function (bc::BufferedMap!)() end -@testset "Internal tests" begin +@testset "Absint struct vector of vector" begin f = BufferedMap!([[2.7]]) df = BufferedMap!([[3.1]]) @test autodiff(Forward, Duplicated(f, df))[1] ≈ 3.1 end + +@testset "Absint sum vector of vector" begin + a = [[2.7]] + da = [[3.1]] + @test autodiff(Forward, sum, Duplicated(a, da))[1] ≈ [3.1] +end From b668a13144bb300fe5b3bb0ee7842102b44a3a2a Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 5 Nov 2024 15:57:17 -0600 Subject: [PATCH 420/495] Fix idiv (#2065) --- src/absint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 2ae9957103..7d96d26f3d 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -513,7 +513,7 @@ function abs_typeof( shouldLoad = false offset %= sizeof(Int) else - sz = actual_size(ET) + sz = max(1, actual_size(ET)) offset %= sz end 
end From 5e0548c7ca6611ae6fc01e1adee69ed99a205423 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 5 Nov 2024 15:57:28 -0600 Subject: [PATCH 421/495] Adapt 1.11 to ptr of generic (#2066) --- src/utils.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 1cb3504725..0e23ca486b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -394,11 +394,7 @@ else @inline function typed_fieldtype(@nospecialize(T::Type), i::Int) if T <: GenericMemoryRef && i == 1 || T <: GenericMemory && i == 2 eT = eltype(T) - if !allocatedinline(eT) && Base.isconcretetype(eT) - Ptr{Ptr{eT}} - else - Ptr{eT} - end + Ptr{eT} else fieldtype(T, i) end From b9652d2b481a671e9cceec94f1fd9e89304d237f Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 6 Nov 2024 01:00:12 -0600 Subject: [PATCH 422/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a0f5a7190e..7608975ecb 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4, 0.8.5" -Enzyme_jll = "0.0.162" +Enzyme_jll = "0.0.163" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 1c09430159c4abd5863a9c1a45d474841a7f4fd4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 16:43:52 +0100 Subject: [PATCH 423/495] Add integration tests for Bijectors (#2037) * Add integration tests for Bijectors * Remove unnecessary collect calls * Mark one more broken test * Update to which Bijectors tests to run * Bring in another fixed test --- .github/workflows/CI.yml | 5 +- test/integration/Bijectors/Project.toml | 9 + test/integration/Bijectors/runtests.jl | 208 ++++++++++++++++++ .../{ => DynamicExpressions}/Project.toml | 0 .../runtests.jl} | 0 5 files changed, 220 insertions(+), 2 deletions(-) create mode 100644 
test/integration/Bijectors/Project.toml create mode 100644 test/integration/Bijectors/runtests.jl rename test/integration/{ => DynamicExpressions}/Project.toml (100%) rename test/integration/{DynamicExpressions.jl => DynamicExpressions/runtests.jl} (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0f47a044b6..c1bd30fe7e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -264,6 +264,7 @@ jobs: - ubuntu-latest test: - DynamicExpressions + - Bijectors steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -273,8 +274,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" run: | - julia --color=yes --project=test/integration -e 'using Pkg; Pkg.develop([PackageSpec(; path) for path in (".", "lib/EnzymeCore")]); Pkg.instantiate()' - julia --color=yes --project=test/integration --threads=auto --check-bounds=yes test/integration/${{ matrix.test }}.jl + julia --color=yes --project=test/integration/${{ matrix.test }} -e 'using Pkg; Pkg.develop([PackageSpec(; path) for path in (".", "lib/EnzymeCore")]); Pkg.instantiate()' + julia --color=yes --project=test/integration/${{ matrix.test }} --threads=auto --check-bounds=yes test/integration/${{ matrix.test }}/runtests.jl shell: bash docs: name: Documentation diff --git a/test/integration/Bijectors/Project.toml b/test/integration/Bijectors/Project.toml new file mode 100644 index 0000000000..2b8c2f46c2 --- /dev/null +++ b/test/integration/Bijectors/Project.toml @@ -0,0 +1,9 @@ +[deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" + +[compat] +Bijectors = "=0.13.16" +FiniteDifferences = "0.12.32" +StableRNGs = "1.0.2" diff --git a/test/integration/Bijectors/runtests.jl b/test/integration/Bijectors/runtests.jl new file mode 100644 index 0000000000..23e6136561 --- /dev/null +++ b/test/integration/Bijectors/runtests.jl @@ -0,0 +1,208 
@@ +module BijectorsIntegrationTests + +using Bijectors: Bijectors +using Enzyme: Enzyme +using FiniteDifferences: FiniteDifferences +using LinearAlgebra: LinearAlgebra +using Random: randn +using StableRNGs: StableRNG +using Test: @test, @test_broken, @testset + +rng = StableRNG(23) + +""" +Enum type for choosing Enzyme autodiff modes. +""" +@enum ModeSelector Neither Forward Reverse Both + +""" +Type for specifying a test case for `Enzyme.gradient`. + +The test will check the accuracy of the gradient of `func` at `value` against `finitediff`, +with both forward and reverse mode autodiff. `name` is for diagnostic printing. +`runtime_activity`, `broken`, `skip` are for specifying whether to use +`Enzyme.set_runtime_activity` or not, whether the test is broken, and whether the test is so +broken we can't even run `@test_broken` on it (because it crashes Julia). All of them take +values `Neither`, `Forward`, `Reverse` or `Both`, to specify which mode to apply the setting +to. `splat` is for specifying whether to call the function as `func(value)` or as +`func(value...)`. + +Default values are `name=nothing`, `runtime_activity=Neither`, `broken=Neither`, +`skip=Neither`, and `splat=false`. +""" +struct TestCase + func::Function + value + name::Union{String, Nothing} + runtime_activity::ModeSelector + broken::ModeSelector + skip::ModeSelector + splat::Bool +end + +# Default values for most arguments. +function TestCase( + f, value; + name=nothing, runtime_activity=Neither, broken=Neither, skip=Neither, splat=false +) + return TestCase(f, value, name, runtime_activity, broken, skip, splat) +end + +""" +Test Enzyme.gradient, both Forward and Reverse mode, against FiniteDifferences.grad. +""" +function test_grad(case::TestCase; rtol=1e-6, atol=1e-6) + @nospecialize + f = case.func + # We'll call the function as f(x...), so wrap in a singleton tuple if need be. + x = case.splat ? 
case.value : (case.value,) + finitediff = FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, x...)[1] + + f_mode = if (case.runtime_activity === Both || case.runtime_activity === Forward) + Enzyme.set_runtime_activity(Enzyme.Forward) + else + Enzyme.Forward + end + r_mode = if (case.runtime_activity === Both || case.runtime_activity === Reverse) + Enzyme.set_runtime_activity(Enzyme.Reverse) + else + Enzyme.Reverse + end + + if !(case.skip === Forward) && !(case.skip === Both) + if case.broken === Both || case.broken === Forward + @test_broken( + Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, + rtol = rtol, + atol = atol, + ) + else + @test( + Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, + rtol = rtol, + atol = atol, + ) + end + end + + if !(case.skip === Reverse) && !(case.skip === Both) + if case.broken === Both || case.broken === Reverse + @test_broken( + Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, + rtol = rtol, + atol = atol, + ) + else + @test( + Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, + rtol = rtol, + atol = atol, + ) + end + end + return nothing +end + +""" +A helper function that returns a TestCase that evaluates sum(bijector(inverse(bijector)(x))) +""" +function sum_b_binv_test_case( + bijector, dim; runtime_activity=Neither, name=nothing, broken=Neither, skip=Neither +) + if name === nothing + name = string(bijector) + end + b_inv = Bijectors.inverse(bijector) + return TestCase( + x -> sum(bijector(b_inv(x))), + randn(rng, dim); + runtime_activity=runtime_activity, name=name, broken=broken, skip=skip + ) +end + +@testset "Bijectors integration tests" begin + test_cases = TestCase[ + sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3), + sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0), + sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), + sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0)), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3), 
+ sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0), + sum_b_binv_test_case(Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), 3), + sum_b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), + sum_b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), + sum_b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), + sum_b_binv_test_case(Bijectors.PDBijector(), (3, 3)), + sum_b_binv_test_case(Bijectors.PDVecBijector(), 3), + sum_b_binv_test_case( + Bijectors.Permute([ + 0 1 0; + 1 0 0; + 0 0 1 + ]), + (3, 3), + ), + # TODO(mhauru) Both modes broken because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2035 + sum_b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3); broken=Both), + sum_b_binv_test_case(Bijectors.RadialLayer(3), 3), + sum_b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), + sum_b_binv_test_case(Bijectors.Scale(0.2), 3), + sum_b_binv_test_case(Bijectors.Shift(-0.4), 3), + sum_b_binv_test_case(Bijectors.SignFlip(), 3), + sum_b_binv_test_case(Bijectors.SimplexBijector(), 3), + sum_b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3), + + # Below, some test cases that don't fit the sum_b_binv_test_case mold. 
+ + TestCase( + function (x) + b = Bijectors.RationalQuadraticSpline([-0.2, 0.1, 0.5], [-0.3, 0.3, 0.9], [1.0, 0.2, 1.0]) + binv = Bijectors.inverse(b) + return sum(binv(b(x))) + end, + randn(rng); + name="RationalQuadraticSpline on scalar", + ), + + TestCase( + function (x) + b = Bijectors.OrderedBijector() + binv = Bijectors.inverse(b) + return sum(binv(b(x))) + end, + randn(rng, 7); + name="OrderedBijector", + ), + + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer) + x = x[6:7] + return Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x) + end, + randn(rng, 7); + name="PlanarLayer7", + ), + + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer) + x = reshape(x[6:end], 2, :) + return sum(Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x)) + end, + randn(rng, 11); + name="PlanarLayer11", + ), + ] + + @testset "$(case.name)" for case in test_cases + test_grad(case) + end +end + +end diff --git a/test/integration/Project.toml b/test/integration/DynamicExpressions/Project.toml similarity index 100% rename from test/integration/Project.toml rename to test/integration/DynamicExpressions/Project.toml diff --git a/test/integration/DynamicExpressions.jl b/test/integration/DynamicExpressions/runtests.jl similarity index 100% rename from test/integration/DynamicExpressions.jl rename to test/integration/DynamicExpressions/runtests.jl From ae171bfaa37201a3e3b39d8bae57e1cdfe7f653f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 8 Nov 2024 12:00:19 -0600 Subject: [PATCH 424/495] Fix deletion of return (#2073) * Fix deletion of return * Update optimize.jl --- src/compiler/optimize.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/compiler/optimize.jl 
b/src/compiler/optimize.jl index 94bfd0bfef..cf57d28302 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1942,7 +1942,7 @@ function propagate_returned!(mod::LLVM.Module) un = LLVM.user(u) push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) end - delete_writes_into_removed_args(fn, toremove) + delete_writes_into_removed_args(fn, toremove, keepret) nm = LLVM.name(fn) #try nfn = LLVM.Function( @@ -1988,7 +1988,7 @@ function propagate_returned!(mod::LLVM.Module) end end -function delete_writes_into_removed_args(fn::LLVM.Function, toremove) +function delete_writes_into_removed_args(fn::LLVM.Function, toremove, keepret::Bool) args = collect(parameters(fn)) for tr in toremove tr = tr + 1 @@ -2032,6 +2032,9 @@ function delete_writes_into_removed_args(fn::LLVM.Function, toremove) end end end + if !keepret && LLVM.API.LLVMIsAReturnInst(cur) != C_NULL + LLVM.API.LLVMSetOperand(cur, 0, LLVM.UndefValue(value_type(cval))) + end throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))")) end end From 31df08b376634e2926e1d147c67284b20c493c22 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 8 Nov 2024 16:06:58 -0600 Subject: [PATCH 425/495] Add rnumber/rarray (#2075) * Add rnumber/rarray * fix * Update compiler.jl --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 56 ++++++++++++++++++++++++++++++++ src/compiler.jl | 56 +++++++++++++++++++++++++++++--- 4 files changed, 109 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 7608975ecb..cfbcd7fe76 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.4, 0.8.5" +EnzymeCore = "0.8.6" Enzyme_jll = "0.0.163" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml 
b/lib/EnzymeCore/Project.toml index 18c3bbad00..270dd35056 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.5" +version = "0.8.6" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index c751aaac38..44166c3477 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -592,4 +592,60 @@ Return a new mode with its [`ABI`](@ref) set to the chosen type. """ function set_abi end + +""" + Primitive Type usable within Reactant. See Reactant.jl for more information. +""" +@static if isdefined(Core, :BFloat16) + const ReactantPrimitive = Union{ + Bool, + Int8, + UInt8, + Int16, + UInt16, + Int32, + UInt32, + Int64, + UInt64, + Float16, + Core.BFloat16, + Float32, + Float64, + Complex{Float32}, + Complex{Float64}, + } +else + const ReactantPrimitive = Union{ + Bool, + Int8, + UInt8, + Int16, + UInt16, + Int32, + UInt32, + Int64, + UInt64, + Float16, + Float32, + Float64, + Complex{Float32}, + Complex{Float64}, + } +end + +""" + Abstract Reactant Array type. See Reactant.jl for more information +""" +abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end +@inline Base.eltype(::RArray{T}) where T = T +@inline Base.eltype(::Type{<:RArray{T}}) where T = T + +""" + Abstract Reactant Number type. 
See Reactant.jl for more information +""" +abstract type RNumber{T<:ReactantPrimitive} <: Number end +@inline Base.eltype(::RNumber{T}) where T = T +@inline Base.eltype(::Type{<:RNumber{T}}) where T = T + + end # module EnzymeCore diff --git a/src/compiler.jl b/src/compiler.jl index 64e6703cd6..b96a432dc1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -160,6 +160,8 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data end const nofreefns = Set{String}(( + "ClientGetDevice", + "BufferOnCPU", "pcre2_match_8", "julia.gcroot_flush", "pcre2_jit_stack_assign_8", @@ -311,6 +313,8 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ClientGetDevice", + "BufferOnCPU", "pcre2_match_data_create_from_pattern_8", "ijl_typeassert", "jl_typeassert", @@ -517,6 +521,8 @@ end end end +@inline numbereltype(::Type{<:EnzymeCore.RNumber{T}}) where {T} = T +@inline ptreltype(::Type{<:EnzymeCore.RArray{T}}) where {T} = T @inline ptreltype(::Type{Ptr{T}}) where {T} = T @inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T @inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T @@ -644,10 +650,21 @@ end return ActiveState end + if T <: EnzymeCore.RNumber + return active_reg_inner( + numbereltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) + end + if T <: Ptr || T <: Core.LLVMPtr || T <: Base.RefValue || - T <: Array || + T <: Array || T <: EnzymeCore.RArray is_arrayorvararg_ty(T) if justActive return AnyState @@ -762,7 +779,7 @@ end end # if abstract it must be by reference - if Base.isabstracttype(T) + if Base.isabstracttype(T) || T == Tuple if AbstractIsMixed return MixedState else @@ -779,11 +796,11 @@ end end @assert !Base.isabstracttype(T) - if !(Base.isconcretetype(T) || (T <: Tuple && T != Tuple) || T isa UnionAll) + if !(Base.isconcretetype(T) || T <: Tuple || T isa UnionAll) throw(AssertionError("Type $T is not concrete type or concrete tuple")) end - nT = if T <: 
Tuple && T != Tuple && !(T isa UnionAll) + nT = if T <: Tuple && !(T isa UnionAll) Tuple{( ntuple(length(T.parameters)) do i Base.@_inline_meta @@ -1108,7 +1125,7 @@ struct Return2 end function force_recompute!(mod::LLVM.Module) - for f in functions(mod), bb in blocks(f), inst in instructions(bb) + for f in functions(mod), bb in blocks(f), inst in collect(instructions(bb)) if isa(inst, LLVM.LoadInst) has_loaded = false for u in LLVM.uses(inst) @@ -1137,8 +1154,24 @@ function force_recompute!(mod::LLVM.Module) metadata(inst)["enzyme_nocache"] = MDNode(LLVM.Metadata[]) end end + if isa(inst, LLVM.CallInst) + cf = LLVM.called_operand(inst) + if isa(cf, LLVM.Function) + if LLVM.name(cf) == "llvm.julia.gc_preserve_begin" + has_use = false + for u2 in LLVM.uses(inst) + has_use = true + break + end + if !has_use + eraseInst(bb, inst) + end + end + end + end end end + function permit_inlining!(f::LLVM.Function) for bb in blocks(f), inst in instructions(bb) # remove illegal invariant.load and jtbaa_const invariants @@ -3294,6 +3327,18 @@ function annotate!(mod, mode) end end + for fname in ( + "UnsafeBufferPointer", + ) + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_math", "__dynamic_cast")) + end + end + end + end + for fname in ( "jl_f_getfield", "ijl_f_getfield", @@ -3301,6 +3346,7 @@ function annotate!(mod, mode) "ijl_get_nth_field_checked", "jl_f__svec_ref", "ijl_f__svec_ref", + "UnsafeBufferPointer" ) if haskey(funcs, fname) for fn in funcs[fname] From 03d70353863e245ceae1dc04a201575073509be2 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sun, 10 Nov 2024 00:34:35 -0600 Subject: [PATCH 426/495] Add abi set for forward/rev types --- lib/EnzymeCore/src/EnzymeCore.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 44166c3477..614bedbe55 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -332,6 +332,7 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa @inline set_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,true}() @inline clear_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,false}() +@inline set_abi(::Type{ReverseMode{ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten}}, ::Type{NewABI}) where {ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,RuntimeActivity,NewABI,Holomorphic,ErrIfFuncWritten} @inline set_abi(::ReverseMode{ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten}, ::Type{NewABI}) where {ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,RuntimeActivity,NewABI,Holomorphic,ErrIfFuncWritten}() @inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,true,ABI,Holomorphic,ErrIfFuncWritten}() @@ -483,6 +484,7 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline set_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where 
{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,true,RuntimeActivity}() @inline clear_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,false,RuntimeActivity}() +@inline set_abi(::Type{ForwardMode{ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity}}, ::Type{NewABI}) where {ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{ReturnPrimal,NewABI,ErrIfFuncWritten,RuntimeActivity} @inline set_abi(::ForwardMode{ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{NewABI}) where {ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{ReturnPrimal,NewABI,ErrIfFuncWritten,RuntimeActivity}() @inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,true}() From cfb733b785425707e387d94b0fab16b215e7604d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 10 Nov 2024 21:57:49 -0600 Subject: [PATCH 427/495] Mark reshape as nocapture (#2078) --- src/compiler.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index b96a432dc1..e8a5c6ebcb 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3704,6 +3704,16 @@ function annotate!(mod, mode) end end end + + for fname in ("jl_reshape_array", "ijl_reshape_array") + if haskey(funcs, fname) + for fn in funcs[fname] + push!(parameter_attributes(fn, 3), LLVM.EnumAttribute("readonly")) + push!(parameter_attributes(fn, 3), LLVM.EnumAttribute("nocapture")) + end + end + end + # Key of jl_eqtable_get/put is inactive, definitionally for fname in ("jl_eqtable_put", "ijl_eqtable_put") if haskey(funcs, fname) From 1c594dc39d7377779c945d917ea23b60aa8076ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 12 Nov 2024 22:42:11 
-0600 Subject: [PATCH 428/495] Improve broadcast index analysis (#2079) * Improve broadcast index analysis * fixup * more * reduce ci err * fix * fix * fix * conditionally disable * with test --- src/compiler/interpreter.jl | 453 +++++++++++++++++++++++++++++++++++- test/optimize.jl | 12 + 2 files changed, 464 insertions(+), 1 deletion(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 51e20f8dc4..f87a63a750 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -43,6 +43,7 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter forward_rules::Bool reverse_rules::Bool deferred_lower::Bool + broadcast_rewrite::Bool handler::T end @@ -53,6 +54,7 @@ function EnzymeInterpreter( forward_rules::Bool, reverse_rules::Bool, deferred_lower::Bool = true, + broadcast_rewrite::Bool = true, handler = nothing ) @assert world <= Base.get_world_counter() @@ -79,6 +81,7 @@ function EnzymeInterpreter( forward_rules, reverse_rules, deferred_lower, + broadcast_rewrite, handler ) end @@ -89,8 +92,9 @@ EnzymeInterpreter( world::UInt, mode::API.CDerivativeMode, deferred_lower::Bool = true, + broadcast_rewrite::Bool = true, handler = nothing -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params @@ -350,6 +354,404 @@ else end +# julia> @btime Base.copyto!(dst, src); +# 668.438 ns (0 allocations: 0 bytes) + +# inp = rand(2,3,4,5); +# src = Base.Broadcast.preprocess(inp, 
convert(Base.Broadcast.Broadcasted{Nothing}, Base.Broadcast.instantiate(Base.broadcasted(Main.sin, inp)))); +# +# idx = Base.eachindex(src); +# +# src2 = sin.(inp); +# +# dst = zero(inp); +# lindex_v1(idx, dst, src); +# @assert dst == sin.(inp) +# +# dst = zero(inp); +# lindex_v1(idx, dst, src2); +# @assert dst == sin.(inp) +# +# @btime lindex_v1(idx, dst, src) +# # 619.140 ns (0 allocations: 0 bytes) +# +# @btime lindex_v1(idx, dst, src2) +# # 153.258 ns (0 allocations: 0 bytes) + +@generated function lindex_v1(idx::BC2, dest, src) where BC2 + if BC2 <: Base.CartesianIndices + nloops = BC2.parameters[1] + exprs = Expr[] + tot = :true + idxs = Symbol[] + lims = Symbol[] + for i in 1:nloops + sym = Symbol("lim_$i") + push!(lims, sym) + sidx = Symbol("idx_$i") + push!(idxs, sidx) + push!(exprs, quote + $sym = idx.indices[$i].stop + end) + if tot == :true + tot = quote $sym != 0 end + else + tot = quote $tot && ($sym != 0) end + end + end + + loops = quote + @inbounds dest[$(idxs...)] = @inbounds Base.Broadcast._broadcast_getindex(src, Base.CartesianIndex($(idxs...))) + end + + # for (sidx, lim) in zip(reverse(idxs), reverse(lims)) + for (sidx, lim) in zip(idxs, lims) + loops = quote + let $sidx = 0 + @inbounds while true + $sidx += 1 + $loops + if $sidx == $lim + break + end + $(Expr(:loopinfo, Symbol("julia.simdloop"), nothing)) # Mark loop as SIMD loop + end + end + end + end + + return quote + Base.@_inline_meta + $(exprs...) 
+ if $tot + $loops + end + end + else + return quote + Base.@_inline_meta + @inbounds @simd for I in idx + dest[I] = src[I] + end + end + end +end + +# inp = rand(2,3,4,5); +# # inp = [2.0 3.0; 4.0 5.0; 7.0 9.0] +# src = Base.Broadcast.preprocess(inp, convert(Base.Broadcast.Broadcasted{Nothing}, Base.Broadcast.instantiate(Base.broadcasted(Main.sin, inp)))); +# +# idx = Base.eachindex(src); +# +# src2 = sin.(inp); +# +# dst = zero(inp); +# lindex_v2(idx, dst, src); +# @assert dst == sin.(inp) +# +# dst = zero(inp); +# lindex_v2(idx, dst, src2); +# @assert dst == sin.(inp) +# +# @btime lindex_v2(idx, dst, src) +# # 1.634 μs (0 allocations: 0 bytes) +# +# @btime lindex_v2(idx, dst, src2) +# # 1.617 μs (0 allocations: 0 bytes) +@generated function lindex_v2(idx::BC2, dest, src, ::Val{Checked}=Val(true)) where {BC2, Checked} + if BC2 <: Base.CartesianIndices + nloops = BC2.parameters[1] + exprs = Union{Expr,Symbol}[] + tot = :true + idxs = Symbol[] + lims = Symbol[] + + total = :1 + for i in 1:nloops + sym = Symbol("lim_$i") + push!(lims, sym) + sidx = Symbol("idx_$i") + push!(idxs, sidx) + push!(exprs, quote + $sym = idx.indices[$i].stop + end) + if tot == :true + tot = quote $sym != 0 end + total = sym + else + tot = quote $tot && ($sym != 0) end + total = quote $total * $sym end + end + end + + push!(exprs, quote total = $total end) + + lexprs = Expr[] + + if Checked + for (lidx, lim) in zip(idxs, lims) + push!(lexprs, quote + $lidx = Base.urem_int(tmp, $lim) + 1 + tmp = Base.udiv_int(tmp, $lim) + end) + end + else + idxs = [quote I+1 end] + end + + return quote + Base.@_inline_meta + $(exprs...) + if $tot + let I = 0 + @inbounds while true + let tmp = I + $(lexprs...) 
+ @inbounds dest[I+1] = @inbounds Base.Broadcast._broadcast_getindex(src, Base.CartesianIndex($(idxs...))) + end + I += 1 + if I == total + break + end + $(Expr(:loopinfo, Symbol("julia.simdloop"), nothing)) # Mark loop as SIMD loop + end + end + end + end + else + return quote + Base.@_inline_meta + @inbounds @simd for I in idx + dest[I] = src[I] + end + end + end +end + + +# inp = rand(2,3,4,5); +# src = Base.Broadcast.preprocess(inp, convert(Base.Broadcast.Broadcasted{Nothing}, Base.Broadcast.instantiate(Base.broadcasted(Main.sin, inp)))); +# +# idx = Base.eachindex(src); +# +# src2 = sin.(inp); +# +# dst = zero(inp); +# lindex_v3(idx, dst, src); +# @assert dst == sin.(inp) +# +# dst = zero(inp); +# lindex_v3(idx, dst, src2); +# @assert dst == sin.(inp) +# +# @btime lindex_v3(idx, dst, src) +# # 568.065 ns (0 allocations: 0 bytes) + +# @btime lindex_v3(idx, dst, src2) +# # 23.906 ns (0 allocations: 0 bytes) +@generated function lindex_v3(idx::BC2, dest, src) where BC2 + if BC2 <: Base.CartesianIndices + nloops = BC2.parameters[1] + exprs = Union{Expr,Symbol}[] + tot = :true + idxs = Symbol[] + lims = Symbol[] + + condition = :true + todo = Tuple{Type, Tuple}[(src, ())] + + function index(x, ::Tuple{}) + return x + end + + function index(x, path) + if path[1] isa Symbol + return quote + $(index(x, Base.tail(path))).$(path[1]) + end + else + return quote getindex($(index(x, Base.tail(path))), $(path[1])) end + end + end + + legal = true + while length(todo) != 0 + cur, path = pop!(todo) + if cur <: AbstractArray + if condition == :true + condition = quote idx.indices == axes($(index(:src, path))) end + else + condition = quote $condition && idx.indices == axes($(index(:src, path))) end + end + continue + end + if cur <: Base.Broadcast.Extruded + if condition == :true + condition = quote all(($(index(:src, path))).keeps) end + else + condition = quote $condition && all(($(index(:src, path))).keeps) end + end + push!(todo, (cur.parameters[1], (:x, path...))) + 
continue + end + if cur == src && cur <: Base.Broadcast.Broadcasted + for (i, v) in enumerate(cur.parameters[4].parameters) + push!(todo, (v, (i, :args, path...))) + end + continue + end + if cur <: AbstractFloat + continue + end + legal = false + end + + if legal + return quote + Base.@_inline_meta + if $condition + lindex_v2(idx, dest, src, Val(false)) + else + lindex_v1(idx, dest, src) + end + end + else + return quote + Base.@_inline_meta + lindex_v1(idx, dest, src) + end + end + else + return quote + Base.@_inline_meta + @inbounds @simd for I in idx + dest[I] = src[I] + end + end + end +end + +# Override Base.copyto!(dest::AbstractArray, bc::Broadcasted{Nothing}) with +# a form which provides better analysis of loop indices +@inline function override_bc_copyto!(dest::AbstractArray, bc::Base.Broadcast.Broadcasted{Nothing}) + axdest = Base.axes(dest) + axbc = Base.axes(bc) + axdest == axbc || Base.Broadcast.throwdm(axdest, axbc) + + if bc.args isa Tuple{AbstractArray} + A = bc.args[1] + if axdest == Base.axes(A) + if bc.f === Base.identity + Base.copyto!(dest, A) + return dest + end + end + end + + # The existing code is rather slow for broadcast in practice: https://github.com/EnzymeAD/Enzyme.jl/issues/1434 + src = Base.Broadcast.preprocess(dest, bc) + idx = Base.eachindex(src) + @inline Enzyme.Compiler.Interpreter.lindex_v3(idx, dest, src) + return dest +end + +@generated function same_sized(x::Tuple) + result = :true + prev = nothing + for i in 1:length(x.parameters) + if x.parameters[i] <: Number + continue + end + if prev == nothing + prev = quote + sz = size(x[$i]) + end + continue + end + if result == :true + result = quote + sz == size(x[$i]) + end + else + result = quote + result && sz == size(x[$i]) + end + end + end + return quote + Base.@_inline_meta + $prev + return $result + end +end + + +Base.@propagate_inbounds overload_broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices 
+Base.@propagate_inbounds overload_broadcast_getindex(::Ref{Type{T}}, I) where {T} = T +# Tuples are statically known to be singleton or vector-like +Base.@propagate_inbounds overload_broadcast_getindex(A::Tuple{Any}, I) = A[1] +Base.@propagate_inbounds overload_broadcast_getindex(A::Tuple, I) = error("unhandled") # A[I[1]] +Base.@propagate_inbounds overload_broadcast_getindex(A, I) = A[I] + +@inline function override_bc_materialize(bc) + if bc.args isa Tuple{AbstractArray} && bc.f === Base.identity + return copy(bc.args[1]) + end + ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args) + dest = similar(bc, ElType) + if all(isa_array_or_number, bc.args) && same_sized(bc.args) + @inbounds @simd for I in 1:length(bc) + val = Base.Broadcast._broadcast_getindex_evalf(bc.f, map(Base.Fix2(overload_broadcast_getindex, I), bc.args)...) + dest[I] = val + end + else + Base.copyto!(dest, bc) + end + return dest +end + +struct MultiOp{Position, NumUsed, F1, F2} + f1::F1 + f2::F2 +end + +@generated function (m::MultiOp{Position, NumUsed})(args::Vararg{Any, N}) where {N, Position, NumUsed} + f2args = Union{Symbol, Expr}[] + for i in Position:(Position+NumUsed) + push!(f2args, :(args[$i])) + end + f1args = Union{Symbol, Expr}[] + for i in 1:Position + push!(f1args, :(args[$i])) + end + push!(f1args, quote + f2($(f2args...)) + end) + for i in (Position+NumUsed):N + push!(f1args, :(args[$i])) + end + return quote + Base.@_inline_meta + f1($(f1args...)) + end +end + +@inline function array_or_number(@nospecialize(Ty)) + return Ty <: AbstractArray || Ty <: Number +end + +@inline function isa_array_or_number(@nospecialize(x)) + return x isa AbstractArray || x isa Number +end + +@inline function num_or_eltype(@nospecialize(Ty)) + if Ty <: AbstractArray + eltype(Ty) + else + return Ty + end +end + function abstract_call_known( interp::EnzymeInterpreter, @nospecialize(f), @@ -384,6 +786,55 @@ function abstract_call_known( ) end end + + if interp.broadcast_rewrite + if f === 
Base.materialize && length(argtypes) == 2 + bcty = widenconst(argtypes[2]) + if Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && all(array_or_number, bcty.parameters[4].parameters) && any(Base.Fix2(Base.:<:, AbstractArray), bcty.parameters[4].parameters) + fnty = bcty.parameters[3] + eltys = map(num_or_eltype, bcty.parameters[4].parameters) + retty = Core.Compiler._return_type(interp, Tuple{fnty, eltys...}) + if Base.isconcretetype(retty) + arginfo2 = ArgInfo( + fargs isa Nothing ? nothing : + [:(Enzyme.Compiler.Interpreter.override_bc_materialize), fargs[2:end]...], + [Core.Const(Enzyme.Compiler.Interpreter.override_bc_materialize), argtypes[2:end]...], + ) + return abstract_call_known( + interp, + Enzyme.Compiler.Interpreter.override_bc_materialize, + arginfo2, + si, + sv, + max_methods, + ) + end + end + end + + if f === Base.copyto! && length(argtypes) == 3 + # Ideally we just override uses of the AbstractArray base class, but + # I don't know how to override the method in base, without accidentally overriding + # it for say CuArray or other users. For safety, we only override for Array + if widenconst(argtypes[2]) <: Array && + widenconst(argtypes[3]) <: Base.Broadcast.Broadcasted{Nothing} + + arginfo2 = ArgInfo( + fargs isa Nothing ? 
nothing : + [:(Enzyme.Compiler.Interpreter.override_bc_copyto!), fargs[2:end]...], + [Core.Const(Enzyme.Compiler.Interpreter.override_bc_copyto!), argtypes[2:end]...], + ) + return abstract_call_known( + interp, + Enzyme.Compiler.Interpreter.override_bc_copyto!, + arginfo2, + si, + sv, + max_methods, + ) + end + end + end @static if VERSION < v"1.11.0-" else diff --git a/test/optimize.jl b/test/optimize.jl index 4792ac0570..d13a6ed752 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -1,6 +1,18 @@ using Enzyme, LinearAlgebra, Test using Random, Statistics +# check that our broadcast interpreter fix is correct for scalars +function bcast_sum(A) + s = 0.0 + for i in 1:3 + s += abs2.(A[i]) + end + return s +end +@testset "Broadcast interpreter" begin + @test autodiff(Forward, bcast_sum, Duplicated([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))[1] ≈ 28.0 +end + function gcloaded_fixup(dest, src) N = size(src) dat = src.data From 9a47d322a36bcd222af7d661f964868d13a60a04 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 13 Nov 2024 13:57:49 -0600 Subject: [PATCH 429/495] Update Project.toml (#2088) --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index cfbcd7fe76..131d57cd5a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.14" +version = "0.13.15" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.6" -Enzyme_jll = "0.0.163" +Enzyme_jll = "0.0.164" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From e15e7d9bd571aa5ed66656338cf0094927f769ec Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 13 Nov 2024 13:58:10 -0600 Subject: [PATCH 430/495] Cholmod struct ccall (#2086) --- src/absint.jl | 13 
+++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 7d96d26f3d..1d5fed1403 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -207,7 +207,9 @@ end function should_recurse(@nospecialize(typ2), arg_t, byref, dl) sz = sizeof(dl, arg_t) if byref != GPUCompiler.BITS_VALUE - @assert sz == sizeof(Int) + if sz != sizeof(Int) + throw(AssertionError("non bits type $byref of $typ2 has size $sz != sizeof(Int) from arg type $arg_t")) + end return false else if actual_size(typ2) != sz @@ -509,7 +511,14 @@ function abs_typeof( ET = eltype(typ) byref = GPUCompiler.MUT_REF typ = ET - if !Base.allocatedinline(typ) + # We currently make the assumption that Ptr{T} either represents a ptr which could be generated by + # julia code (for example pointer(x) ), or is a storage container for an array / memory + # in the latter case, it may need an extra level of indirection because of boxing. It is semantically + # consistent here to consider Ptr{T} to represent the ptr to the boxed value in that case [and we essentially + # add the extra pointer offset when loading here]. However for pointers constructed by ccall outside julia + # to a julia object, which are not inline by type but appear so, like SparseArrays, this is a problem + # and merits further investigation. 
x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/2085 + if !Base.allocatedinline(typ) && typ != SparseArrays.cholmod_dense_struct shouldLoad = false offset %= sizeof(Int) else From 75e6725c07e46352df774904c7289f45e37d6d33 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 13 Nov 2024 13:58:21 -0600 Subject: [PATCH 431/495] Also handle non intrinsic versions (#2089) * Also handle non intrinsic versions * Update Project.toml --------- Co-authored-by: William Moses --- src/compiler.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index e8a5c6ebcb..0b044d27bb 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7480,6 +7480,10 @@ end function_attributes(wrapper_f), StringAttribute("implements", llname), ) + push!( + function_attributes(wrapper_f), + StringAttribute("implements2", n * pf) + ) end end end @@ -7586,6 +7590,10 @@ end function_attributes(wrapper_f), StringAttribute("implements", llname), ) + push!( + function_attributes(wrapper_f), + StringAttribute("implements2", n * pf) + ) end end end From 595141b55b55e1dab7ffdfbf26c2c6c5f70fa954 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 15 Nov 2024 22:03:02 -0800 Subject: [PATCH 432/495] Remove inactive_type(::Type{Function}) method (#2093) Does this break anything? --- src/internal_rules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 539223ffa5..9c647aefca 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -123,7 +123,6 @@ Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= no @inline EnzymeRules.inactive_type(v::Type{Union{}}) = true @inline EnzymeRules.inactive_type(v::Type{Char}) = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:Integer} = true -@inline EnzymeRules.inactive_type(v::Type{Function}) = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:DataType} = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:Module} = true @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true From ba15c1cfadd18d0fa1421c42a233e55a12a59973 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 15 Nov 2024 22:03:20 -0800 Subject: [PATCH 433/495] Mark IO inactive_type (#2092) * Test make_zero with IO (issue #2091) * Mark IO inactive_type Fixes #2091 --- src/internal_rules.jl | 1 + test/abi.jl | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 9c647aefca..3f968b4cb9 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -129,6 +129,7 @@ Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= no @inline EnzymeRules.inactive_type(v::Type{Core.MethodMatch}) = true @inline EnzymeRules.inactive_type(v::Type{Core.Compiler.WorldRange}) = true @inline EnzymeRules.inactive_type(v::Type{Core.MethodInstance}) = true +@inline EnzymeRules.inactive_type(v::Type{T}) where {T<:IO} = true # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) diff --git a/test/abi.jl b/test/abi.jl index 1c62741ef1..20747f2aaa 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -494,12 +494,31 @@ mutable struct ConstVal const y::Float64 end +struct WithIO{F} + v::Vector{Float64} + callback::F + function WithIO(v, io) + callback() = println(io, "hello") + return new{typeof(callback)}(v, callback) + end +end + @testset "Make Zero" begin v = ConstVal(2.0, 3.0) dv = make_zero(v) @test dv isa ConstVal @test dv.x ≈ 0.0 @test dv.y ≈ 0.0 + + f = WithIO([1.0, 2.0], stdout) + df = @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + make_zero(f) + catch e + showerror(stderr, e) + end + @test df.v == [0.0, 0.0] + @test df.callback === f.callback end @testset "Type inference" begin From 9a0c9461c39d38e036b01e255940bf5ea9f725ac Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Nov 2024 01:48:21 -0500 Subject: [PATCH 434/495] 1.10 symbol res fix (#2087) * 1.10 symbol res fix * Update validation.jl * Update validation.jl --- src/compiler/validation.jl | 144 +++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 69 deletions(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 0ce2e61eb1..8ee16730c3 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -946,88 +946,94 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) LLVM.API.LLVMInstructionEraseFromParent(inst) else - if fn == "jl_lazy_load_and_lookup" - res = ccall( - :jl_lazy_load_and_lookup, - Ptr{Cvoid}, - (Any, Cstring), - flib, - 
fname, - ) - else - res = ccall( - :ijl_lazy_load_and_lookup, - Ptr{Cvoid}, - (Any, Cstring), - flib, - fname, - ) + res = try + if fn == "jl_lazy_load_and_lookup" + ccall( + :jl_lazy_load_and_lookup, + Ptr{Cvoid}, + (Any, Cstring), + flib, + fname, + ) + else + ccall( + :ijl_lazy_load_and_lookup, + Ptr{Cvoid}, + (Any, Cstring), + flib, + fname, + ) + end + catch + nothing end - - replaceWith = - LLVM.ConstantInt(LLVM.IntType(8 * sizeof(Int)), reinterpret(UInt, res)) - for u in LLVM.uses(inst) - st = LLVM.user(u) - if isa(st, LLVM.StoreInst) && - LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst - ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) - for u in LLVM.uses(ptr) - ld = LLVM.user(u) - if isa(ld, LLVM.LoadInst) - b = IRBuilder() - position!(b, ld) - for u in LLVM.uses(ld) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) + + if res != nothing + replaceWith = + LLVM.ConstantInt(LLVM.IntType(8 * sizeof(Int)), reinterpret(UInt, res)) + for u in LLVM.uses(inst) + st = LLVM.user(u) + if isa(st, LLVM.StoreInst) && + LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) + for u in LLVM.uses(ptr) + ld = LLVM.user(u) + if isa(ld, LLVM.LoadInst) + b = IRBuilder() + position!(b, ld) + for u in LLVM.uses(ld) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end end + replace_uses!( + ld, + LLVM.inttoptr!(b, replaceWith, value_type(inst)), + ) end - replace_uses!( - ld, - LLVM.inttoptr!(b, replaceWith, value_type(inst)), - ) end end end - end - - b = IRBuilder() - position!(b, inst) - replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) - for u in LLVM.uses(inst) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, LLVM.PHIInst) - if all( - x -> first(x) == inst || first(x) == replacement, - LLVM.incoming(u), - ) - - for u in LLVM.uses(u) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, 
LLVM.BitCastInst) - for u1 in LLVM.uses(u) - u1 = LLVM.user(u1) - if isa(u1, LLVM.CallInst) - push!(calls, u1) + + b = IRBuilder() + position!(b, inst) + replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.PHIInst) + if all( + x -> first(x) == inst || first(x) == replacement, + LLVM.incoming(u), + ) + + for u in LLVM.uses(u) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.BitCastInst) + for u1 in LLVM.uses(u) + u1 = LLVM.user(u1) + if isa(u1, LLVM.CallInst) + push!(calls, u1) + end end + replace_uses!( + u, + LLVM.inttoptr!(b, replaceWith, value_type(u)), + ) end - replace_uses!( - u, - LLVM.inttoptr!(b, replaceWith, value_type(u)), - ) end end end end + replace_uses!(inst, replacement) + LLVM.API.LLVMInstructionEraseFromParent(inst) end - replace_uses!(inst, replacement) - LLVM.API.LLVMInstructionEraseFromParent(inst) end elseif fn == "julia.call" || fn == "julia.call2" dest = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0)) From f034de4b4e21ba623143f862b3eeb16092a74612 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Nov 2024 03:19:09 -0500 Subject: [PATCH 435/495] Type fix (#2097) --- src/compiler/optimize.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index cf57d28302..0b5c4ce683 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1139,6 +1139,9 @@ function nodecayed_phis!(mod::LLVM.Module) GTy ) nphi = call!(nb, GTy, gcloaded, LLVM.Value[base_obj, nphi]) + if value_type(nphi) != ty + nphi = bitcast!(nb, nphi, ty) + end end else nphi = addrspacecast!(nb, nphi, ty) From 41ee9cd451a34652b91c959f8fe690c147e5de6f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Nov 2024 03:20:47 -0500 Subject: [PATCH 436/495] Disable multiarg (#2096) --- test/runtests.jl | 18 +++++++++--------- 1 file changed, 9 
insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3b338eff18..d9b7dc97ac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1702,16 +1702,16 @@ end R = zeros(6,6) dR = zeros(6, 6) - @static if VERSION ≥ v"1.10-" - @test_broken autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) + @static if VERSION ≥ v"1.11-" else - autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) - @test 1.0 ≈ dR[1, 1] - @test 1.0 ≈ dR[2, 2] - @test 1.0 ≈ dR[3, 3] - @test 1.0 ≈ dR[4, 4] - @test 1.0 ≈ dR[5, 5] - @test 0.0 ≈ dR[6, 6] + @test_broken autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) + # autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) + # @test 1.0 ≈ dR[1, 1] + # @test 1.0 ≈ dR[2, 2] + # @test 1.0 ≈ dR[3, 3] + # @test 1.0 ≈ dR[4, 4] + # @test 1.0 ≈ dR[5, 5] + # @test 0.0 ≈ dR[6, 6] end end From 292f43d85ddb4e44c11c14ac8d4bfc9852aea685 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Nov 2024 11:49:29 -0500 Subject: [PATCH 437/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 131d57cd5a..a70af4338e 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.6" -Enzyme_jll = "0.0.164" +Enzyme_jll = "0.0.165" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From b63f36f64f4f4b96e28e2e963817c024cdb22d86 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Nov 2024 11:49:45 -0500 Subject: [PATCH 438/495] batch active ret (#2098) * batch active ret * Update rrules.jl * Update rrules.jl --- src/rules/customrules.jl | 4 ++++ test/rrules.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 96661849b2..03611e7d26 100644 --- 
a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1179,6 +1179,10 @@ end insert!(args, tape_idx, tape) end if RT <: Active + if width != 1 + emit_error(B, orig, "Not yet supported: Enzyme custom rule of batch size=$width, and active return $RT") + return tapeV + end llty = convert(LLVMType, RT) diff --git a/test/rrules.jl b/test/rrules.jl index cd41b49716..3c6bf558a4 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -61,6 +61,32 @@ end @test dx ≈ [102.0] end +function augmented_primal(config::RevConfigWidth{2}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active) + if needs_primal(config) + return AugmentedReturn(func.val(x.val), nothing, nothing) + else + return AugmentedReturn(nothing, nothing, nothing) + end +end + +function reverse(config::RevConfigWidth{2}, ::Const{typeof(f)}, dret::Active, tape, x::Active) + return ((10+2*x.val*dret.val,100+2*x.val*dret.val,)) +end + +function fip_2(out, in) + out[] = f(in[]) + nothing +end + +@testset "Batch ActiveReverse Rules" begin + out = BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(3.0))) + in = BatchDuplicated(Ref(2.0), (Ref(0.0), Ref(0.0))) + # TODO: Not yet supported: Enzyme custom rule of batch size=2, and active return EnzymeCore.Active{Float64} + @test_throws Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Enzyme.Reverse, fip_2, out, in) + @test_broken in.dvals[1][] ≈ 104.0 + @test_broken in.dvals[1][] ≈ 42.0 +end + function alloc_sq(x) return Ref(x*x) end From 5e6a82dd08e74666822b9d7b2b46c36b075668ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 17 Nov 2024 16:28:07 -0500 Subject: [PATCH 439/495] Fix embarassing bug on keepret (#2101) --- src/compiler/optimize.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 0b5c4ce683..8c42ee2b55 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -2037,7 +2037,8 @@ function delete_writes_into_removed_args(fn::LLVM.Function, toremove, keepret::B end if 
!keepret && LLVM.API.LLVMIsAReturnInst(cur) != C_NULL LLVM.API.LLVMSetOperand(cur, 0, LLVM.UndefValue(value_type(cval))) - end + continue + end throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))")) end end From be7984735939b819417c23a28d1671b0c398a7e2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 18 Nov 2024 09:43:08 -0500 Subject: [PATCH 440/495] Handle case for SparseMatrixCSC when output is Const (#2100) * Fix sparse matmul rule when output is const * Remove FiniteDifferences * Simplify because we can never be in a BatchDuplicated state * Handle if B is BatchDuplicated but C is Const --- src/internal_rules.jl | 29 +++++++++++++++++++++++++++++ test/internal_rules.jl | 14 ++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 3f968b4cb9..71c700e73a 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -844,6 +844,35 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, dCs[i] .*= β.val end end + else + # C is constant so there is no gradient information to compute + + dα = if !isa(α, Const) + if N == 1 + zero(α.val) + else + ntuple(Val(N)) do i + Base.@_inline_meta + zero(α.val) + end + end + else + nothing + end + + + dβ = if !isa(β, Const) + if N == 1 + zero(β.val) + else + ntuple(Val(N)) do i + Base.@_inline_meta + zero(β.val) + end + end + else + nothing + end end return (nothing, nothing, nothing, dα, dβ) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index fb5926a1d3..a91ddaa620 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -734,6 +734,14 @@ end are_activities_compatible(Tret, Tret, Tv) || continue test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) end + + # Test with a const output and active α and β + (_,_,_,dα, dβ), = autodiff(Reverse, LinearAlgebra.mul!, Const, Const(C), Const(M), Const(v), Active(α), Active(β)) + @test dα ≈ 0 + @test 
dβ ≈ 0 + + + end @testset "SparseArrays spmatmat reverse rule" begin @@ -754,6 +762,12 @@ end are_activities_compatible(Tret, Tv) || continue test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) end + + # Test with a const output and active α and β + (_,_,_,dα, dβ), = autodiff(Reverse, LinearAlgebra.mul!, Const, Const(C), Const(M), Const(v), Active(α), Active(β)) + @test dα ≈ 0 + @test dβ ≈ 0 + end end # InternalRules From 5185b0f36c4ef74cbf2b7b0d3663c82aa8cc3333 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Nov 2024 13:53:38 -0500 Subject: [PATCH 441/495] Speed up split (#2107) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0b044d27bb..67d3db2836 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1316,7 +1316,7 @@ function removed_ret_parms(F::LLVM.Function) end if parmrem !== nothing str = value(parmrem) - for v in split(str, ",") + for v in eachsplit(str, ",") push!(parmsRemoved, parse(UInt64, v)) end end From ba4c22a1fa99676f2c9abc0e81baeec453e32e53 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Nov 2024 13:57:01 -0500 Subject: [PATCH 442/495] Update interpreter.jl (#2106) --- src/compiler/interpreter.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index f87a63a750..22761c2d1a 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -674,7 +674,7 @@ end end else result = quote - result && sz == size(x[$i]) + $result && sz == size(x[$i]) end end end From 9c6899c289d68525c098ba56a9d353bd820a3a12 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Nov 2024 15:31:03 -0500 Subject: [PATCH 443/495] Handle mixed return of unstable (#2104) * Handle mixed return of unstable * with test * Update mixed.jl * Update mixed.jl --- src/Enzyme.jl | 19 +++++----- src/compiler.jl | 65 +++++++++++++++++++++++--------- 
src/rules/jitrules.jl | 86 +++++++++++++++++++++++-------------------- test/mixed.jl | 41 ++++++++++++++++++++- 4 files changed, 143 insertions(+), 68 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2e8643744b..17ec83980b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -682,16 +682,15 @@ code, as well as high-order differentiation. if A isa UnionAll rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) - rt = Core.Compiler.return_type(f.val, tt) - A2 = A{rt} - if rt == Union{} - throw(ErrorException("Return type inferred to be Union{}. Giving up.")) - end - else - @assert A isa DataType - rt = A - if rt == Union{} - throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + A2 = A{rt} + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end + else + @assert A isa DataType + rt = A + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) end end diff --git a/src/compiler.jl b/src/compiler.jl index 67d3db2836..2b7d441c27 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4491,12 +4491,17 @@ function create_abi_wrapper( push!(sret_types, AnonymousStruct(NTuple{width,literal_rt})) end elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + rty = if Base.isconcretetype(literal_rt) + Base.RefValue{literal_rt} + else + (Base.RefValue{T} where T <: literal_rt) + end if width == 1 - push!(sret_types, Base.RefValue{literal_rt}) + push!(sret_types, rty) else push!( sret_types, - AnonymousStruct(NTuple{width,Base.RefValue{literal_rt}}), + AnonymousStruct(NTuple{width,rty}), ) end end @@ -4633,6 +4638,7 @@ function create_abi_wrapper( convty = convert(LLVMType, T′; allow_boxed = true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) + @assert Base.isconcretetype(T′) al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter") al = bitcast!(builder, al, 
LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, params[i], al) @@ -4692,6 +4698,7 @@ function create_abi_wrapper( parmsi = params[i] if T <: BatchMixedDuplicated + @assert Base.isconcretetype(T′) if GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{T′}}) njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue) parmsi = bitcast!( @@ -4812,26 +4819,37 @@ function create_abi_wrapper( for idx = 1:width pv = (width == 1) ? eval : extract_value!(builder, eval, idx - 1) - al0 = + irt = eltype(rettype) + ires = if Base.isconcretetype(irt) al = emit_allocobj!( builder, Base.RefValue{eltype(rettype)}, "batchmixedret", ) - llty = value_type(pv) - al = bitcast!( - builder, - al, - LLVM.PointerType(llty, addrspace(value_type(al))), - ) - store!(builder, pv, al) - emit_writebarrier!( - builder, - get_julia_inner_types(builder, al0, pv), - ) + al0 = al + llty = value_type(pv) + al = bitcast!( + builder, + al, + LLVM.PointerType(llty, addrspace(value_type(al))), + ) + store!(builder, pv, al) + emit_writebarrier!( + builder, + get_julia_inner_types(builder, al0, pv), + ) + al0 + else + # emit_allocobj!( + # builder, + # emit_apply_type!(builder, Base.RefValue, [emit_jltypeof!(builder, pv)]), + # "batchmixedret", + # ) + pv + end ival = - (width == 1) ? al0 : - insert_value!(builder, ival, al0, idx - 1) + (width == 1) ? 
ires : + insert_value!(builder, ival, ires, idx - 1) end eval = ival end @@ -8223,11 +8241,21 @@ end if rettype <: Duplicated || rettype <: DuplicatedNoNeed push!(sret_types, jlRT) elseif rettype <: MixedDuplicated - push!(sret_types, Base.RefValue{jlRT}) + rty = if Base.isconcretetype(jlRT) + Base.RefValue{jlRT} + else + (Base.RefValue{T} where T <: jlRT) + end + push!(sret_types, rty) elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed push!(sret_types, AnonymousStruct(NTuple{width,jlRT})) elseif rettype <: BatchMixedDuplicated - push!(sret_types, AnonymousStruct(NTuple{width,Base.RefValue{jlRT}})) + rty = if Base.isconcretetype(jlRT) + Base.RefValue{jlRT} + else + (Base.RefValue{T} where T <: jlRT) + end + push!(sret_types, AnonymousStruct(NTuple{width,rty})) elseif CC <: AugmentedForwardThunk push!(sret_types, Nothing) elseif rettype <: Const @@ -8363,6 +8391,7 @@ end @assert length(types) == length(ccexprs) + if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) return quote Base.@_inline_meta diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index bf98aaf885..546231c3f4 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -438,32 +438,33 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) args = ($(wrapped...),) $(MakeTypes...) 
- FT = Core.Typeof(f) - dupClosure0 = if ActivityTup[1] - !guaranteed_const(FT) - else - false - end - - tt = Tuple{$(ElTypes...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - - annotationA = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt,$Width} - elseif $Width != 1 && annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt,$Width} - else - annotation0 - end - internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) gv = Core.getglobal(args[1].val, args[2].val) @assert sizeof(gv) == 0 (nothing, gv, nothing, Const) else + FT = Core.Typeof(f) + tt = Tuple{$(ElTypes...)} world = codegen_world_age(FT, tt) + dupClosure0 = if ActivityTup[1] + !guaranteed_const(FT) + else + false + end + + rt = Compiler.primal_return_type(Reverse, Val(world), FT, tt) + + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotationA = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt,$Width} + elseif $Width != 1 && annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt,$Width} + else + annotation0 + end + opt_mi = Val(world) forward, adjoint = thunk( opt_mi, @@ -492,7 +493,11 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) ) return ReturnType(($(nres...), tape)) elseif annotation <: Active - shadow_return = $shadowretinit + shadow_return = if Base.isconcretetype(rt) + $shadowretinit + else + initShadow + end tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( internal_tape, shadow_return, @@ -634,31 +639,33 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act end quote - $(active_refs...) - args = ($(wrapped...),) - $(MakeTypes...) - - FT = Core.Typeof(f) - dupClosure0 = if ActivityTup[1] - !guaranteed_const(FT) + if f isa typeof(Core.getglobal) else - false - end + $(active_refs...) + args = ($(wrapped...),) + $(MakeTypes...) 
- tt = Tuple{$(ElTypes...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + FT = Core.Typeof(f) + dupClosure0 = if ActivityTup[1] + !guaranteed_const(FT) + else + false + end - annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt,$Width} - else - annotation0 - end + tt = Tuple{$(ElTypes...)} - if f isa typeof(Core.getglobal) - else world = codegen_world_age(FT, tt) + rt = Compiler.primal_return_type(Reverse, Val(world), FT, tt) + + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt,$Width} + else + annotation0 + end + opt_mi = Val(world) _, adjoint = thunk( opt_mi, @@ -1488,6 +1495,7 @@ end if vec isa Base.RefValue vecld = vec[] T = Core.Typeof(vecld) + @assert !(vecld isa Base.RefValue) vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), expr) else val = @inbounds vec[idx_in_vec] diff --git a/test/mixed.jl b/test/mixed.jl index dc4c510b23..77a09192d0 100644 --- a/test/mixed.jl +++ b/test/mixed.jl @@ -81,4 +81,43 @@ end @testset "Mixed PrimalError" begin @test_throws AssertionError autodiff(Reverse, bad_abi, MixedDuplicated(Foobar(2, 3, 4, 5, 6.0), Ref(Foobar(2, 3, 4, 5, 6.0)))) -end \ No newline at end of file +end + + + +function flattened_unique_values(tupled) + flattened = flatten_tuple(tupled) + + return nothing +end + +@inline flatten_tuple(a::Tuple) = tuple(inner_flatten_tuple(a[1])..., inner_flatten_tuple(a[2:end])...) +@inline flatten_tuple(a::Tuple{<:Any}) = tuple(inner_flatten_tuple(a[1])...) 
+ +@inline inner_flatten_tuple(a) = tuple(a) +@inline inner_flatten_tuple(a::Tuple) = flatten_tuple(a) +@inline inner_flatten_tuple(a::Tuple{}) = () + + +struct Center end + +struct Field{LX} + grid :: Float64 + data :: Float64 +end + +@testset "Mixed Unstable Return" begin + grid = 1.0 + data = 2.0 + f1 = Field{Center}(grid, data) + f2 = Field{Center}(grid, data) + f3 = Field{Center}(grid, data) + f4 = Field{Center}(grid, data) + f5 = Field{Nothing}(grid, data) + thing = (f1, f2, f3, f4, f5) + dthing = Enzyme.make_zero(thing) + + dedC = autodiff(Enzyme.Reverse, + flattened_unique_values, + Duplicated(thing, dthing)) +end From 995d7f3d94796ea39400471550b2134baa10a911 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 20 Nov 2024 00:06:34 -0500 Subject: [PATCH 444/495] Fix global var store (#2068) --- src/compiler/validation.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 8ee16730c3..72fcb35efe 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -457,6 +457,19 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) initfn, _ = get_base_and_offset(LLVM.initializer(fn_got); offsetAllowed=false, inttoptr=false) loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst opv = operands(loadfn)[1] + if !isa(opv, LLVM.GlobalVariable) + for iv in instructions(last(blocks(initfn))) + if !isa(iv, LLVM.StoreInst) + continue + end + gv = operands(iv)[2] + if !isa(gv, LLVM.GlobalVariable) + continue + end + opv = gv + break + end + end if !isa(opv, LLVM.GlobalVariable) msg = sprint() do io::IO println( From 665cebd317f212de8e9b3734f48cdcac6b427ee9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 20 Nov 2024 00:07:00 -0500 Subject: [PATCH 445/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a70af4338e..24883933f5 100644 --- a/Project.toml +++ 
b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.15" +version = "0.13.16" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 4f0f333f899593e1828a9ba232f597e8268dcb43 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 20 Nov 2024 23:42:38 -0500 Subject: [PATCH 446/495] Add crappy rule for dA in Sparse matmul (#2109) * Add crappy rule for dA * Make rule better and add tests --- src/internal_rules.jl | 40 +++++++++++++++++++++++++++------------- test/internal_rules.jl | 26 ++++++++++++++------------ 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 71c700e73a..04aca1a66a 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -733,7 +733,7 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(LinearAlgebra.mul!)}, ::Type{RT}, C::Annotation{<:StridedVecOrMat}, - A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + A::Annotation{<:SparseArrays.SparseMatrixCSCUnion}, B::Annotation{<:StridedVecOrMat}, α::Annotation{<:Number}, β::Annotation{<:Number} @@ -761,7 +761,10 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, && !(typeof(C) <: Const) ) ? copy(A.val) : nothing - # cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing + cache_B = ( EnzymeRules.overwritten(config)[6] + && !(typeof(A) <: Const) + && !(typeof(C) <: Const) + ) ? 
copy(B.val) : nothing if !isa(α, Const) cache_α = A.val*B.val @@ -769,7 +772,7 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, cache_α = nothing end - cache = (cache_C, cache_A, cache_α) + cache = (cache_C, cache_A, cache_B, cache_α) return EnzymeRules.AugmentedReturn(primal, shadow, cache) end @@ -778,16 +781,16 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(LinearAlgebra.mul!)}, ::Type{RT}, cache, C::Annotation{<:StridedVecOrMat}, - A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + A::Annotation{<:SparseArrays.SparseMatrixCSCUnion}, B::Annotation{<:StridedVecOrMat}, α::Annotation{<:Number}, β::Annotation{<:Number} ) where {RT} - cache_C, cache_A, cache_α = cache + cache_C, cache_A, cache_B, cache_α = cache Cval = !isnothing(cache_C) ? cache_C : C.val Aval = !isnothing(cache_A) ? cache_A : A.val - # Bval = !isnothing(cache_B) ? cache_B : B.val + Bval = !isnothing(cache_B) ? cache_B : B.val N = EnzymeRules.width(config) if !isa(C, Const) @@ -821,13 +824,24 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, end for i in 1:N - # This rule is incorrect since you need to project dA to have the same - # sparsity pattern as A. - # if !isa(A, Const) - # dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - # #dA .+= α*dC*B' - # mul!(dA, dC, Bval', α.val, true) - # end + if !isa(A, Const) + # dA .+= αdC*B' + # You need to be careful so that dA sparsity pattern does not change. Otherwise + # you will get incorrect gradients. So for now we do the slow and bad way of accumulating + dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[i] + dC = EnzymeRules.width(config) == 1 ? 
C.dval : C.dval[i] + # Now accumulate to preserve the correct sparsity pattern + I, J, _ = SparseArrays.findnz(dA) + for k in eachindex(I, J) + Ik, Jk = I[k], J[k] + tmp = zero(eltype(dA)) + for ti in axes(dC,2) + tmp += dC[Ik, ti]*Bval[Jk, ti] + end + dA[Ik, Jk] += α.val*tmp + end + # mul!(dA, dCs, Bval', α.val, true) + end if !isa(B, Const) #dB .+= α*A'*dC diff --git a/test/internal_rules.jl b/test/internal_rules.jl index a91ddaa620..ad10e88e79 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -721,17 +721,18 @@ end α = 2.0 β = 1.0 - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), Tα in (Const, Active), Tβ in (Const, Active) - are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ)) end - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) - are_activities_compatible(Tret, Tret, Tv) || continue + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), + Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tret, TM, Tv) || continue test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) end @@ -740,8 +741,6 @@ end @test dα ≈ 0 @test dβ ≈ 0 - - end @testset "SparseArrays spmatmat reverse rule" begin @@ -751,15 +750,18 @@ end α = 2.0 β = 1.0 - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in 
(Const, Duplicated, BatchDuplicated), Tα in (Const, Active), Tβ in (Const, Active) - are_activities_compatible(Tret, Tv, Tα, Tβ) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ)) + end - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) - are_activities_compatible(Tret, Tv) || continue + + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), + Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tret, TM, Tv) || continue test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) end From 5c373fb9b7bd9087e0ba8d0fca61935c4298dd02 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 21 Nov 2024 17:30:45 -0500 Subject: [PATCH 447/495] fix fn attr 1.11 (#2114) --- src/compiler/validation.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 72fcb35efe..848b5734e7 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -577,6 +577,9 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) newf, _ = get_function!(mod, fused_name, FT) + while isa(newf, LLVM.ConstantExpr) + newf = operands(newf)[1] + end push!(function_attributes(newf), StringAttribute("enzyme_math", fname)) # TODO we can make this relocatable if desired by having restore lookups re-create this got initializer/etc # metadata(newf)["enzymejl_flib"] = flib From d096464a8a6bbb7e87b74bd938541c243492e2d3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 26 Nov 2024 01:46:11 +0100 Subject: [PATCH 448/495] Utilities for splitting and unsplitting 
mode objects (#1979) * Utilities for splitting and unsplitting mode objects * Remove Manifest * Rename and add tests * Add ABI tests * Fix tests * Add set_abi on mode type * Rename to Split and Combined --- lib/EnzymeCore/src/EnzymeCore.jl | 99 +++++++++++++++++++++++- lib/EnzymeCore/test/mode_modification.jl | 25 ++++++ lib/EnzymeCore/test/runtests.jl | 3 + 3 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 lib/EnzymeCore/test/mode_modification.jl diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 614bedbe55..f949664b6a 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -395,7 +395,7 @@ Subtype of [`Mode`](@ref) for split reverse mode differentiation, to use in [`au - [`set_abi`](@ref) - [`ReverseSplitModified`](@ref), [`ReverseSplitWidth`](@ref) """ -struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end """ const ReverseSplitNoPrimal @@ -432,6 +432,9 @@ Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to @inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() @inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where 
{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline set_abi(::Type{ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit}}, ::Type{NewABI}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit,NewABI<:ABI} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,NewABI,Holomorphic,ErrIfFuncWritten,ShadowInit} +@inline set_abi(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit}, ::Type{NewABI}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit,NewABI<:ABI} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,NewABI,Holomorphic,ErrIfFuncWritten,ShadowInit}() + @inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() @inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() @@ -594,6 +597,100 @@ Return a new mode with its [`ABI`](@ref) set to the chosen type. 
""" function set_abi end +""" + Split( + ::ReverseMode, [::Val{ReturnShadow}, ::Val{Width}, ::Val{ModifiedBetween}, ::Val{ShadowInit}] + ) + +Turn a [`ReverseMode`](@ref) object into a [`ReverseModeSplit`](@ref) object while preserving as many of the settings as possible. +The rest of the settings can be configured with optional positional arguments of `Val` type. + +This function acts as the identity on a [`ReverseModeSplit`](@ref). + +See also [`Combined`](@ref). +""" +function Split( + ::ReverseMode{ + ReturnPrimal, + RuntimeActivity, + ABI, + Holomorphic, + ErrIfFuncWritten + }, + ::Val{ReturnShadow}=Val(true), + ::Val{Width}=Val(0), + ::Val{ModifiedBetween}=Val(true), + ::Val{ShadowInit}=Val(false), +) where { + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetween, + ABI, + Holomorphic, + ErrIfFuncWritten, + ShadowInit +} + mode_split = ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetween, + ABI, + Holomorphic, + ErrIfFuncWritten, + ShadowInit + }() + return mode_split +end + +Split(mode::ReverseModeSplit, args...) = mode + +""" + Combined(::ReverseMode) + +Turn a [`ReverseModeSplit`](@ref) object into a [`ReverseMode`](@ref) object while preserving as many of the settings as possible. + +This function acts as the identity on a [`ReverseMode`](@ref). + +See also [`Split`](@ref). +""" +function Combined( + ::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetween, + ABI, + Holomorphic, + ErrIfFuncWritten, + ShadowInit + } +) where { + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetween, + ABI, + Holomorphic, + ErrIfFuncWritten, + ShadowInit +} + mode_unsplit = ReverseMode{ + ReturnPrimal, + RuntimeActivity, + ABI, + Holomorphic, + ErrIfFuncWritten + }() + return mode_unsplit +end + +Combined(mode::ReverseMode) = mode """ Primitive Type usable within Reactant. See Reactant.jl for more information. 
diff --git a/lib/EnzymeCore/test/mode_modification.jl b/lib/EnzymeCore/test/mode_modification.jl new file mode 100644 index 0000000000..fde34d23bb --- /dev/null +++ b/lib/EnzymeCore/test/mode_modification.jl @@ -0,0 +1,25 @@ +using EnzymeCore +using EnzymeCore: InlineABI, ReverseModeSplit, Split, Combined, set_runtime_activity, set_err_if_func_written, set_abi +using Test + +@testset "Split / unsplit mode" begin + @test Split(Reverse) == ReverseSplitNoPrimal + @test Split(ReverseWithPrimal) == ReverseSplitWithPrimal + @test Split(ReverseSplitNoPrimal) == ReverseSplitNoPrimal + @test Split(ReverseSplitWithPrimal) == ReverseSplitWithPrimal + + @test Split(set_runtime_activity(Reverse)) == set_runtime_activity(ReverseSplitNoPrimal) + @test Split(set_err_if_func_written(Reverse)) == set_err_if_func_written(ReverseSplitNoPrimal) + @test Split(set_abi(Reverse, InlineABI)) == set_abi(ReverseSplitNoPrimal, InlineABI) + + @test Split(Reverse, Val(:ReturnShadow), Val(:Width), Val(:ModifiedBetween), Val(:ShadowInit)) == ReverseModeSplit{false,:ReturnShadow,false,:Width,:ModifiedBetween,EnzymeCore.DefaultABI,false,false,:ShadowInit}() + + @test Combined(Reverse) == Reverse + @test Combined(ReverseWithPrimal) == ReverseWithPrimal + @test Combined(ReverseSplitNoPrimal) == Reverse + @test Combined(ReverseSplitWithPrimal) == ReverseWithPrimal + + @test Combined(set_runtime_activity(ReverseSplitNoPrimal)) == set_runtime_activity(Reverse) + @test Combined(set_err_if_func_written(ReverseSplitNoPrimal)) == set_err_if_func_written(Reverse) + @test Combined(set_abi(ReverseSplitNoPrimal, InlineABI)) == set_abi(Reverse, InlineABI) +end diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index 2fb7fd74f2..b12065384d 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -33,4 +33,7 @@ using EnzymeCore @testset "Miscellaneous" begin include("misc.jl") end + @testset "Mode modification" begin + include("mode_modification.jl") + end 
end From 8e481be3c8542a1f6b7387018a157681fbae8509 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 26 Nov 2024 20:28:06 -0500 Subject: [PATCH 449/495] Fix segfault on return type (#2117) * Fix segfault on return type * fix * fix * fix * fix * cleanup * cleanup * fix * Update runtests.jl * Update parallelrules.jl * Update parallelrules.jl * Update jitrules.jl * Update jitrules.jl * Update jitrules.jl * Update jitrules.jl * fix * cleanup * cleanup * more fix * fix * Update compiler.jl * Update compiler.jl * Update compiler.jl * Update compiler.jl * fix * jelly * fix * fix * more nospec * more types * more types * moretype * more cleanup * cleanup * further types * more types * unused ctx * ix * fixup * ix * ix * ix * fix * fix * fix * fix * fix * fix * fix * fix * fix * ix * Update absint.jl * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update compiler.jl * Update jitrules.jl * Update compiler.jl * fix * fix * fix * fix * fix --- Project.toml | 4 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/rules.jl | 36 +- src/Enzyme.jl | 140 +++--- src/absint.jl | 20 +- src/compiler.jl | 808 ++++++++++++++++++++------------- src/compiler/interpreter.jl | 46 +- src/compiler/optimize.jl | 165 ++++--- src/compiler/reflection.jl | 28 +- src/compiler/utils.jl | 46 +- src/compiler/validation.jl | 28 +- src/jlrt.jl | 115 +++-- src/rules/customrules.jl | 52 +-- src/rules/jitrules.jl | 223 ++++----- src/rules/parallelrules.jl | 22 +- src/rules/typeunstablerules.jl | 17 +- src/utils.jl | 116 ----- test/ext/chainrulescore.jl | 2 +- test/optimize.jl | 63 +++ test/ruleinvalidation.jl | 8 + test/runtests.jl | 17 +- 21 files changed, 1067 insertions(+), 891 deletions(-) diff --git a/Project.toml b/Project.toml index 24883933f5..369e43f88a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.16" +version = 
"0.13.17" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.6" +EnzymeCore = "0.8.7" Enzyme_jll = "0.0.165" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 270dd35056..28f92d9055 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.6" +version = "0.8.7" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index a7563a2ef7..945951b216 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -171,7 +171,7 @@ end function has_frule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance}=nothing) + caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) @@ -180,7 +180,7 @@ end function has_rrule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance}=nothing) + caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), 
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance}=nothing) + caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -208,18 +208,36 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); matches = result end fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches) - if caller !== nothing - fullmatch || add_mt_backedge!(caller, mt, sig) + if !fullmatch + if caller isa Core.MethodInstance + add_mt_backedge!(caller, mt, sig) + elseif caller isa Core.Compiler.MethodLookupResult + for j = 1:Core.Compiler.length(caller) + cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch + cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance + add_mt_backedge!(cspec, mt, sig) + end + end end if Core.Compiler.isempty(matches) return false else - if caller !== nothing + if caller isa Core.MethodInstance for i = 1:Core.Compiler.length(matches) match = Core.Compiler.getindex(matches, i)::Core.MethodMatch edge = Core.Compiler.specialize_method(match)::Core.MethodInstance add_backedge!(caller, edge, sig) end + elseif caller isa Core.Compiler.MethodLookupResult + for j = 1:Core.Compiler.length(caller) + cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch + cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance + for i = 1:Core.Compiler.length(matches) + match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + add_backedge!(cspec, edge, sig) + end + end end return true end @@ -245,7 +263,7 @@ function inactive end function is_inactive_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance}=nothing) + 
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) return isapplicable(inactive, TT; world, method_table, caller) end @@ -260,7 +278,7 @@ function inactive_noinl end function is_inactive_noinl_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance}=nothing) + caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) return isapplicable(inactive_noinl, TT; world, method_table, caller) end @@ -275,7 +293,7 @@ function noalias end function noalias_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance}=nothing) + caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) return isapplicable(noalias, TT; world, method_table, caller) end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 17ec83980b..5305beff8d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -11,7 +11,9 @@ import EnzymeCore: ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, + Mode, ReverseMode, + ReverseModeSplit, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal @@ -23,7 +25,9 @@ export Forward, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, + Mode, ReverseMode, + ReverseModeSplit, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal @@ -99,7 +103,6 @@ export markType, batch_size, onehot, chunkedonehot using LinearAlgebra import SparseArrays -import EnzymeCore: ReverseMode, ReverseModeSplit, ForwardMode, Mode import EnzymeCore: EnzymeRules export EnzymeRules @@ -158,6 +161,15 @@ end )...} end +@inline function vaEltypeof(args::Vararg{Any,N}) where {N} + return Tuple{( + ntuple(Val(N)) do i + Base.@_inline_meta + eltype(Core.Typeof(args[i])) + end + )...} +end + @inline function vaEltypes(args::Type{Ty}) where {Ty<:Tuple} 
return Tuple{( ntuple(Val(length(Ty.parameters))) do i @@ -330,7 +342,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) point values, but cannot do so for integer values in tuples and structs. """ @inline function autodiff( - rmode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + mode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -353,12 +365,12 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) ModifiedBetweenT = falses_from_args(Nargs + 1) ModifiedBetween = Val(ModifiedBetweenT) - tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + tt = vaEltypeof(args...) FTy = Core.Typeof(f.val) rt = if A isa UnionAll - Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) + Compiler.primal_return_type(mode, FTy, tt) else eltype(A) end @@ -400,7 +412,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) else - Val(codegen_world_age(FTy, tt)) + Val(0) end if (A <: Active && rt <: Complex) && rt != Union{} @@ -522,17 +534,12 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. f::FA, args::Vararg{Annotation,Nargs}, ) where {FA<:Annotation,CMode<:Mode,Nargs} - tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode - Compiler.primal_return_type( - mode, - Val(codegen_world_age(eltype(FA), tt)), - eltype(FA), - tt, - ) - else - Core.Compiler.return_type(f.val, tt) - end + tt = vaEltypeof(args...) + rt = Compiler.primal_return_type( + mode, + eltype(FA), + tt, + ) A = guess_activity(rt, mode) autodiff(mode, f, A, args...) 
end @@ -578,7 +585,7 @@ f(x) = x*x ``` """ @inline function autodiff( - ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, + mode::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -622,12 +629,12 @@ f(x) = x*x ModifiedBetween = Val(falses_from_args(Nargs + 1)) - tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + tt = vaEltypeof(args...) opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) else - Val(codegen_world_age(Core.Typeof(f.val), tt)) + Val(0) end thunk = Enzyme.Compiler.thunk( @@ -654,7 +661,7 @@ Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred( - rmode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + mode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -673,24 +680,23 @@ code, as well as high-order differentiation. if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + tt = vaEltypeof(args...) FTy = Core.Typeof(f.val) - world = codegen_world_age(FTy, tt) A2 = A if A isa UnionAll - rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) - A2 = A{rt} - if rt == Union{} - throw(ErrorException("Return type inferred to be Union{}. Giving up.")) - end - else - @assert A isa DataType - rt = A - if rt == Union{} - throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + rt = Compiler.primal_return_type(mode, FTy, tt) + A2 = A{rt} + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end + else + @assert A isa DataType + rt = A + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) end end @@ -753,10 +759,9 @@ code, as well as high-order differentiation. 
end adjoint_ptr = Compiler.deferred_codegen( - Val(world), FA, - Val(tt′), - Val(A), + A, + tt′, Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, @@ -787,7 +792,7 @@ Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compil code, as well as high-order differentiation. """ @inline function autodiff_deferred( - ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, + mode::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -830,12 +835,12 @@ code, as well as high-order differentiation. else A end - tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + tt = vaEltypeof(args...) - world = codegen_world_age(Core.Typeof(f.val), tt) + FT = Core.Typeof(f.val) if RT isa UnionAll - rt = Core.Compiler.return_type(f.val, tt) + rt = Compiler.primal_return_type(mode, FT, tt) rt = RT{rt} else @assert RT isa DataType @@ -853,10 +858,9 @@ code, as well as high-order differentiation. ModifiedBetween = Val(falses_from_args(Nargs + 1)) adjoint_ptr = Compiler.deferred_codegen( - Val(world), - FA, - Val(tt′), - Val(rt), + Core.Typeof(f), + rt, + tt′, Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, @@ -914,7 +918,7 @@ result, ∂v, ∂A ``` """ @inline function autodiff_thunk( - rs::ReverseModeSplit{ + mode::ReverseModeSplit{ ReturnPrimal, ReturnShadow, RuntimeActivity, @@ -963,7 +967,7 @@ result, ∂v, ∂A opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) else - Val(codegen_world_age(eltype(FA), tt)) + Val(0) end Enzyme.Compiler.thunk( opt_mi, @@ -1057,7 +1061,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float ``` """ @inline function autodiff_thunk( - ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, + mode::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation},Nargs}, @@ -1093,7 +1097,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, 
Duplicated, Duplicated{Float opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) else - Val(codegen_world_age(eltype(FA), tt)) + Val(0) end results = Enzyme.Compiler.thunk( opt_mi, @@ -1112,7 +1116,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float end @inline function tape_type( - ::ReverseModeSplit{ + mode::ReverseModeSplit{ ReturnPrimal, ReturnShadow, RuntimeActivity, @@ -1161,7 +1165,7 @@ end opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), TT) else - Val(codegen_world_age(eltype(FA), primal_tt)) + Val(0) end nondef = Enzyme.Compiler.thunk( opt_mi, @@ -1193,7 +1197,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType @inline function tape_type( parent_job::Union{GPUCompiler.CompilerJob,Nothing}, - ::ReverseModeSplit{ + mode::ReverseModeSplit{ ReturnPrimal, ReturnShadow, RuntimeActivity, @@ -1239,9 +1243,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType primal_tt = Tuple{map(eltype, args)...} - world = codegen_world_age(eltype(FA), primal_tt) - - mi = Compiler.fspec(eltype(FA), TT, world) + mi = Compiler.fspec(eltype(FA), TT) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( @@ -1377,13 +1379,14 @@ result, ∂v, ∂A TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} - world = codegen_world_age(eltype(FA), primal_tt) + rt0 = Compiler.primal_return_type(mode, eltype(FA), primal_tt) + + rt = Compiler.remove_innerty(A2){rt0} primal_ptr = Compiler.deferred_codegen( - Val(world), FA, - Val(TT), - Val(Compiler.remove_innerty(A2)), + rt, + TT, Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, @@ -1394,10 +1397,9 @@ result, ∂v, ∂A Val(RuntimeActivity), ) #=ShadowInit=# adjoint_ptr = Compiler.deferred_codegen( - Val(world), FA, - Val(TT), - Val(Compiler.remove_innerty(A2)), + rt, + TT, Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, @@ -1430,13 +1432,6 @@ result, ∂v, ∂A A2 end - rt = if RT isa UnionAll - 
RT{Core.Compiler.return_type(Tuple{eltype(FA),map(eltype, args)...})} - else - @assert RT isa DataType - RT - end - aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid},FA,rt,TT,width,ReturnPrimal,TapeType}( primal_ptr, @@ -2232,7 +2227,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t ``` """ @inline function jacobian( - ::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + mode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X; n_outs::OutType = nothing, @@ -2286,14 +2281,15 @@ this function will retun an AbstractArray of shape `size(output)` of values of t XT = Core.Typeof(x) MD = Compiler.active_reg_inner(XT, (), nothing, Val(true)) == Compiler.ActiveState #=justActive=# tt = Tuple{XT} - rt = if f isa Const - Core.Compiler.return_type(f.val, tt) + FRT = if f isa Const + Core.Typeof(f.val) else - Core.Compiler.return_type(f, tt) + Core.Typeof(f) end + rt = Compiler.primal_return_type(mode, FRT, tt) + ModifiedBetweenT = (false, false) - FRT = Core.Typeof(f) FA = Const{FRT} if chunk == Val(1) || chunk == nothing diff --git a/src/absint.jl b/src/absint.jl index 1d5fed1403..3205dccc25 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -75,9 +75,9 @@ function absint(arg::LLVM.Value, partial::Bool = false) end if nm == "jl_f_apply_type" || nm == "ijl_f_apply_type" index += 1 - found = [] + found = Any[] legal, Ty = absint(operands(arg)[index], partial) - unionalls = [] + unionalls = TypeVar[] for sarg in operands(arg)[index+1:end-1] slegal, foundv = absint(sarg, partial) if slegal @@ -102,7 +102,7 @@ function absint(arg::LLVM.Value, partial::Bool = false) end if nm == "jl_f_tuple" || nm == "ijl_f_tuple" index += 1 - found = [] + found = Any[] legal = true for sarg in operands(arg)[index:end-1] slegal, foundv = absint(sarg, partial) @@ -250,7 +250,7 @@ function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=fals continue end if isa(larg, 
LLVM.GetElementPtrInst) && - all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) + all(Base.Fix2(isa, LLVM.ConstantInt), operands(larg)[2:end]) b = LLVM.IRBuilder() position!(b, larg) offty = LLVM.IntType(8 * sizeof(Int)) @@ -389,8 +389,8 @@ function abs_typeof( if nm == "jl_f_tuple" || nm == "ijl_f_tuple" index += 1 - found = [] - unionalls = [] + found = Union{Type, TypeVar}[] + unionalls = TypeVar[] legal = true for sarg in operands(arg)[index:end-1] slegal, foundv, _ = abs_typeof(sarg, partial, seenphis) @@ -416,8 +416,6 @@ function abs_typeof( if nm == "jl_f__apply_iterate" || nm == "ijl_f__apply_iterate" index += 1 - found = [] - unionalls = [] legal, iterfn = absint(operands(arg)[index]) index += 1 if legal && iterfn == Base.iterate @@ -426,7 +424,7 @@ function abs_typeof( if legal0 && combfn == Core.apply_type && partial return (true, Type, GPUCompiler.BITS_REF) end - resvals = [] + resvals = Type[] while index != length(operands(arg)) legal, pval, _ = abs_typeof(operands(arg)[index], partial, seenphis) if !legal @@ -727,7 +725,7 @@ function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr ce = operands(ce)[1] elseif opcode(ce) == LLVM.API.LLVMGetElementPtr - if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) + if all(x -> x isa LLVM.ConstantInt && convert(UInt, x) == 0, operands(ce)[2:end]) ce = operands(ce)[1] else break @@ -739,7 +737,7 @@ function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} if isa(ce, LLVM.GlobalVariable) ce = LLVM.initializer(ce) if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) - return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) + return (true, String(map(Base.Fix1(convert, UInt8), collect(ce)[1:(end-1)]))) end end diff --git a/src/compiler.jl b/src/compiler.jl index 
2b7d441c27..e836f4bb8c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -106,13 +106,13 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing), typeof(Base.fma_emulated) => (:fma, 3, nothing), ) -@inline function find_math_method(@nospecialize(func), sparam_vals) +@inline function find_math_method(@nospecialize(func::Type), sparam_vals::Core.SimpleVector) if func ∈ keys(known_ops) name, arity, toinject = known_ops[func] Tys = (Float32, Float64) if length(sparam_vals) == arity - T = first(sparam_vals) + T = first(sparam_vals)::Type legal = T ∈ Tys if legal @@ -450,7 +450,7 @@ end @inline element(::Val{T}) where {T} = T # From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 -@inline function isghostty(ty) +@inline function isghostty(@nospecialize(ty)) if ty === Union{} return true end @@ -950,7 +950,7 @@ using .JIT include("jlrt.jl") -AnyArray(Length::Int) = NamedTuple{ntuple(i -> Symbol(i), Val(Length)),NTuple{Length,Any}} +AnyArray(Length::Int) = NamedTuple{ntuple(Symbol, Val(Length)),NTuple{Length,Any}} struct EnzymeRuntimeException <: Base.Exception msg::Cstring @@ -1202,12 +1202,12 @@ end include("make_zero.jl") -function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world) +function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) funcspec = my_methodinstance(typeof(f), tt, world) nested_codegen!(mode, mod, funcspec, world) end -function prepare_llvm(mod, job, meta) +function prepare_llvm(mod::LLVM.Module, job, meta) interp = GPUCompiler.get_interpreter(job) for f in functions(mod) attributes = function_attributes(f) @@ -1250,7 +1250,7 @@ function nested_codegen!( mode::API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, - world, + world::UInt, ) # TODO: Put a cache here index on `mod` and f->tt @@ -1406,7 +1406,7 
@@ end parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? LLVM.parent(val) : val parent_scope(val::LLVM.Module, depth = 0) = val -parent_scope(val::LLVM.Value, depth = 0) = parent_scope(LLVM.parent(val), depth + 1) +parent_scope(@nospecialize(val::LLVM.Value), depth = 0) = parent_scope(LLVM.parent(val), depth + 1) parent_scope(val::LLVM.Argument, depth = 0) = parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth + 1) @@ -1431,11 +1431,7 @@ function julia_sanitize( stringv = "Enzyme: Found nan while computing derivative of " * string(orig) if orig !== nothing && isa(orig, LLVM.Instruction) bt = GPUCompiler.backtrace(orig) - function printBT(io) - print(io, "\nCaused by:") - Base.show_backtrace(io, bt) - end - stringv *= sprint(io -> Base.show_backtrace(io, bt)) + stringv *= sprint(Base.Fix2(Base.show_backtrace, bt)) end fn, _ = get_function!(mod, "julia.sanitize." * string(ty), FT) @@ -1649,7 +1645,7 @@ function julia_error( created = LLVM.Instruction[] world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) width = get_width(gutils) - function make_batched(cur, B) + function make_batched(@nospecialize(cur::LLVM.Value), B::LLVM.IRBuilder)::LLVM.Value if width == 1 return cur else @@ -1668,7 +1664,7 @@ function julia_error( illegalVal = nothing - function make_replacement(cur::LLVM.Value, prevbb)::LLVM.Value + function make_replacement(@nospecialize(cur::LLVM.Value), prevbb::LLVM.IRBuilder)::LLVM.Value ncur = new_from_original(gutils, cur) if cur in keys(seen) return seen[cur] @@ -1883,7 +1879,7 @@ end end end - if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(x->isa(x, LLVM.ConstantInt), operands(cur)[2:end])) + if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(Base.Fix2(isa, LLVM.ConstantInt), operands(cur)[2:end])) lhs = make_replacement(operands(cur)[1], 
prevbb) if illegal return ncur @@ -2168,7 +2164,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} error("Can't construct tape type for $Type $(string(Type)) $tkind") end -function tape_type(LLVMType::LLVM.LLVMType) +function tape_type(@nospecialize(LLVMType::LLVM.LLVMType)) TT, isAny = to_tape_type(LLVMType.ref) if isAny return AnonymousStruct(Tuple{Any}) @@ -2204,7 +2200,7 @@ current_task_offset() = current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) -function store_nonjl_types!(B, startval, p) +function store_nonjl_types!(B::LLVM.IRBuilder, @nospecialize(startval::LLVM.Value), @nospecialize(p::LLVM.Value)) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2248,7 +2244,7 @@ function store_nonjl_types!(B, startval, p) return end -function get_julia_inner_types(B, p, startvals...; added = LLVM.API.LLVMValueRef[]) +function get_julia_inner_types(B::LLVM.IRBuilder, @nospecialize(p::Union{Nothing, LLVM.Value}), @nospecialize(startvals::Vararg{LLVM.Value}); added = LLVM.API.LLVMValueRef[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2327,7 +2323,7 @@ end function julia_post_cache_store( SI::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef, - R2, + R2::Ptr{UInt64}, )::Ptr{LLVM.API.LLVMValueRef} B = LLVM.IRBuilder(B) SI = LLVM.Instruction(SI) @@ -2445,7 +2441,7 @@ function julia_allocator( Count::LLVM.API.LLVMValueRef, AlignedSize::LLVM.API.LLVMValueRef, IsDefault::UInt8, - ZI, + ZI::Ptr{LLVM.API.LLVMValueRef}, ) B = LLVM.IRBuilder(B) Count = LLVM.Value(Count) @@ -2454,7 +2450,7 @@ function julia_allocator( return julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) end -function fixup_return(B, retval) +function fixup_return(B::LLVM.API.LLVMBuilderRef, retval::LLVM.API.LLVMValueRef) B = LLVM.IRBuilder(B) func = LLVM.parent(position(B)) @@ -2480,7 +2476,7 @@ 
function fixup_return(B, retval) return retval.ref end -function zero_allocation(B, LLVMType, obj, isTape::UInt8) +function zero_allocation(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMTypeRef, obj::LLVM.API.LLVMValueRef, isTape::UInt8) B = LLVM.IRBuilder(B) LLVMType = LLVM.LLVMType(LLVMType) obj = LLVM.Value(obj) @@ -2493,7 +2489,7 @@ function zero_allocation(B, LLVMType, obj, isTape::UInt8) return nothing end -function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) +function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(nobj::LLVM.Value), zeroAll::Bool, @nospecialize(idx::LLVM.Value)) T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) @@ -2570,11 +2566,11 @@ end function zero_allocation( B::LLVM.IRBuilder, - jlType, - LLVMType, - obj, - AlignedSize, - Size, + @nospecialize(jlType::DataType), + @nospecialize(LLVMType::LLVM.LLVMType), + @nospecialize(obj::LLVM.Value), + @nospecialize(AlignedSize::LLVM.Value), + @nospecialize(Size::LLVM.Value), zeroAll::Bool, )::LLVM.API.LLVMValueRef func = LLVM.parent(position(B)) @@ -2657,7 +2653,7 @@ function zero_allocation( ).ref end -function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) +function julia_allocator(B::LLVM.IRBuilder, @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(Count::LLVM.Value), @nospecialize(AlignedSize::LLVM.Value), IsDefault::UInt8, ZI::Ptr{LLVM.API.LLVMValueRef}) func = LLVM.parent(position(B)) mod = LLVM.parent(func) @@ -2704,7 +2700,7 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) Count = trunc!(B, Count, T_size_t) boxed_count = emit_box_int32!(B, Count) end - tag = emit_apply_type!(B, NTuple, (boxed_count, unsafe_to_llvm(B, TT))) + tag = emit_apply_type!(B, NTuple, LLVM.Value[boxed_count, unsafe_to_llvm(B, TT)]) end # Check if Julia version has 
https://github.com/JuliaLang/julia/pull/46914 @@ -2782,7 +2778,7 @@ function julia_deallocator(B::LLVM.API.LLVMBuilderRef, Obj::LLVM.API.LLVMValueRe julia_deallocator(B, Obj) end -function julia_deallocator(B::LLVM.IRBuilder, Obj::LLVM.Value) +function julia_deallocator(B::LLVM.IRBuilder, @nospecialize(Obj::LLVM.Value)) mod = LLVM.parent(LLVM.parent(position(B))) T_void = LLVM.VoidType() @@ -2801,7 +2797,7 @@ function julia_deallocator(B::LLVM.IRBuilder, Obj::LLVM.Value) return LLVM.API.LLVMValueRef(callf.ref) end -function emit_inacterror(B, V, orig) +function emit_inacterror(B::LLVM.API.LLVMBuilderRef, V::LLVM.API.LLVMValueRef, orig::LLVM.API.LLVMValueRef) B = LLVM.IRBuilder(B) curent_bb = position(B) orig = LLVM.Value(orig) @@ -2809,7 +2805,7 @@ function emit_inacterror(B, V, orig) mod = LLVM.parent(fn) bt = GPUCompiler.backtrace(orig) - bts = sprint(io -> Base.show_backtrace(io, bt)) + bts = sprint(Base.Fix2(Base.show_backtrace, bt)) fmt = globalstring_ptr!(B, "%s:\nBacktrace\n" * bts) funcT = LLVM.FunctionType( @@ -2817,7 +2813,7 @@ function emit_inacterror(B, V, orig) LLVMType[LLVM.PointerType(LLVM.Int8Type())], vararg = true, ) - func, _ = get_function!(mod, "jl_errorf", funcT, [EnumAttribute("noreturn")]) + func, _ = get_function!(mod, "jl_errorf", funcT, LLVM.Attribute[EnumAttribute("noreturn")]) call!(B, funcT, func, LLVM.Value[fmt, LLVM.Value(V)]) return nothing @@ -2960,7 +2956,8 @@ function __init__() end # Define EnzymeTarget -Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget end +# Base.@kwdef +struct EnzymeTarget <: AbstractCompilerTarget end GPUCompiler.llvm_triple(::EnzymeTarget) = LLVM.triple(JIT.get_jit()) GPUCompiler.llvm_datalayout(::EnzymeTarget) = LLVM.datalayout(JIT.get_jit()) @@ -3082,9 +3079,9 @@ import .Interpreter: isKWCallSignature Create the methodinstance pair, and lookup the primal return type. 
""" @inline function fspec( - @nospecialize(F), - @nospecialize(TT), - world::Union{Integer,Nothing} = nothing, + @nospecialize(F::Type), + @nospecialize(TT::Type), + world::Union{UInt,Nothing} = nothing, ) # primal function. Inferred here to get return type _tt = (TT.parameters...,) @@ -3100,12 +3097,10 @@ Create the methodinstance pair, and lookup the primal return type. return primal end -@generated function primal_return_type( - ::ReverseMode, - ::Val{world}, - ::Type{FT}, - ::Type{TT}, -) where {world,FT,TT} +function primal_interp_world( + @nospecialize(::ReverseMode), + world::UInt +) mode = Enzyme.API.DEM_ReverseModeCombined CT = @static if VERSION >= v"1.11.0-DEV.1552" @@ -3120,40 +3115,156 @@ end Enzyme.Compiler.GLOBAL_REV_CACHE end - interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) - res = Core.Compiler._return_type(interp, Tuple{FT,TT.parameters...}) - return quote - Base.@_inline_meta - $res - end + Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) end -@generated function primal_return_type( - ::ForwardMode, - ::Val{world}, - ::Type{FT}, - ::Type{TT}, -) where {world,FT,TT} +function primal_interp_world( + @nospecialize(::ForwardMode), + world::UInt +) mode = Enzyme.API.DEM_ForwardMode CT = @static if VERSION >= v"1.11.0-DEV.1552" EnzymeCacheToken( typeof(DefaultCompilerTarget()), false, - GPUCompiler.GLOBAL_METHOD_TABLE, #=always_inline=# + GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# EnzymeCompilerParams, - false, + true, ) else Enzyme.Compiler.GLOBAL_FWD_CACHE end - interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) - res = Core.Compiler._return_type(interp, Tuple{FT,TT.parameters...}) - return quote - Base.@_inline_meta - $res + Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) +end + +@inline primal_interp_world( + @nospecialize(::ReverseModeSplit), + world::UInt) = primal_interp_world(Reverse, world) + +function 
primal_return_type_world( + @nospecialize(mode::Mode), + world::UInt, + @nospecialize(TT::Type), +) + Core.Compiler._return_type(primal_interp_world(mode, world), TT) +end + +function primal_return_type_world( + @nospecialize(mode::Mode), + world::UInt, + mi::Core.MethodInstance, +) + interp = primal_interp_world(mode, world) + something( + Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), + Any, + ) +end + +primal_return_type_world( + @nospecialize(mode::Mode), + world::UInt, + @nospecialize(FT::Type), + @nospecialize(TT::Type), + ) = primal_return_type_world(mode, world, Tuple{FT, TT.parameters...}) + +function primal_return_type_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type)) + @nospecialize + @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) + @assert mode <: Mode + mode = mode() + ft = ft.parameters[1] + tt = tt.parameters[1] + + # validation + ft <: Core.Builtin && + error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") + + # look up the method + method_error = :(throw(MethodError(ft, tt, $world))) + sig = Tuple{ft,tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results + #interp = primal_interp_world(mode, world) + #method_table = Core.Compiler.method_table(interp) + method_table = nothing + mthds = Base._methods_by_ftype( + sig, + method_table, + -1, #=lim=# + world, + false, #=ambig=# + min_world, + max_world, + has_ambig, + ) + stub = Core.GeneratedFunctionStub( + identity, + Core.svec(:methodinstance, :mode, :ft, :tt), + Core.svec(), + ) + mthds === nothing && return stub(world, source, method_error) + length(mthds) == 1 || return stub(world, source, method_error) + + # look up the method and code instance + mtypes, msp, m = mthds[1] + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, 
Any), + m, + mtypes, + msp, + ) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + new_ci.edges = Core.MethodInstance[mi] + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. + + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :mode, :ft, :tt] + new_ci.slotflags = UInt8[0x00 for i = 1:4] + + # return the codegen world age + res = primal_return_type_world(mode, world, mi) + push!(new_ci.code, Core.Compiler.ReturnNode(res)) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below end + new_ci.ssavaluetypes += 1 + + # NOTE: we keep the first entry of the original linetable, and use it for location info + # on the call to check_cache. we can't not have a codeloc (using 0 causes + # corruption of the back trace), and reusing the target function's info + # has as advantage that we see the name of the kernel in the backtraces. 
+ + return new_ci +end + +@eval @inline function primal_return_type(mode::Mode, ft::Type, tt::Type) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, primal_return_type_generator)) end ## @@ -3776,7 +3887,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end -function enzyme_custom_extract_mi(orig::LLVM.Instruction, error = true) +function enzyme_custom_extract_mi(orig::LLVM.Instruction, error::Bool = true) operand = LLVM.called_operand(orig) if isa(operand, LLVM.Function) return enzyme_custom_extract_mi(operand::LLVM.Function, error) @@ -3786,7 +3897,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Instruction, error = true) return nothing, nothing end -function enzyme_custom_extract_mi(orig::LLVM.Function, error = true) +function enzyme_custom_extract_mi(orig::LLVM.Function, error::Bool = true) mi = nothing RT = nothing for fattr in collect(function_attributes(orig)) @@ -3807,7 +3918,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Function, error = true) return mi, RT end -function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error = true) +function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool = true) ty = nothing byref = nothing for fattr in collect(parameter_attributes(fn, idx)) @@ -3850,20 +3961,20 @@ const DumpPreEnzyme = Ref(false) const DumpPostWrap = Ref(false) function enzyme!( - job, - mod, - primalf, - TT, - mode, - width, - parallel, - actualRetType, - wrap, - modifiedBetween, - returnPrimal, - expectedTapeType, - loweredArgs, - boxedArgs, + job::CompilerJob, + mod::LLVM.Module, + primalf::LLVM.Function, + @nospecialize(TT::Type), + mode::API.CDerivativeMode, + width::Int, + parallel::Bool, + @nospecialize(actualRetType::Type), + wrap::Bool, + @nospecialize(modifiedBetween::NTuple{N, Bool} where N), + returnPrimal::Bool, + @nospecialize(expectedTapeType::Type), + loweredArgs::Set{Int}, + boxedArgs::Set{Int}, ) if DumpPreEnzyme[] 
API.EnzymeDumpModuleRef(mod.ref) @@ -4297,15 +4408,15 @@ end function create_abi_wrapper( enzymefn::LLVM.Function, - TT, - rettype, - actualRetType, + @nospecialize(TT::Type), + @nospecialize(rettype::Type), + @nospecialize(actualRetType::Type), Mode::API.CDerivativeMode, augmented, - width, - returnPrimal, - shadow_init, - world, + width::Int, + returnPrimal::Bool, + shadow_init::Bool, + world::UInt, interp, ) is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined @@ -4782,7 +4893,7 @@ function create_abi_wrapper( metadata(val)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) end - @inline function fixup_abi(index, value) + @inline function fixup_abi(index::Int, @nospecialize(value::LLVM.Value)) valty = sret_types[index] # Union becoming part of a tuple needs to be adjusted # See https://github.com/JuliaLang/julia/blob/81afdbc36b365fcbf3ae25b7451c6cb5798c0c3d/src/cgutils.cpp#L3795C1-L3801C121 @@ -4878,7 +4989,7 @@ function create_abi_wrapper( end for shadowv in shadows - c = emit_apply_generic!(builder, [unsafe_to_llvm(builder, add_one_in_place), shadowv]) + c = emit_apply_generic!(builder, LLVM.Value[unsafe_to_llvm(builder, add_one_in_place), shadowv]) if get_subprogram(llvm_f) !== nothing metadata(c)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) @@ -5123,7 +5234,7 @@ struct RemovedParam end # Modified from GPUCompiler classify_arguments function classify_arguments( - source_sig::Type, + @nospecialize(source_sig::Type), codegen_ft::LLVM.FunctionType, has_sret::Bool, has_returnroots::Bool, @@ -5225,7 +5336,7 @@ function classify_arguments( return args end -function isSpecialPtr(Ty) +function isSpecialPtr(@nospecialize(Ty::LLVM.LLVMType)) if !isa(Ty, LLVM.PointerType) return false end @@ -5239,7 +5350,7 @@ mutable struct CountTrackedPointers derived::Bool end -function CountTrackedPointers(T) +function CountTrackedPointers(@nospecialize(T::LLVM.LLVMType)) res = CountTrackedPointers(0, true, false) if isa(T, 
LLVM.PointerType) @@ -5276,7 +5387,7 @@ function CountTrackedPointers(T) end # must deserve sret -function deserves_rooting(T) +function deserves_rooting(@nospecialize(T::LLVM.LLVMType)) tracked = CountTrackedPointers(T) @assert !tracked.derived if tracked.count != 0 && !tracked.all @@ -5287,7 +5398,7 @@ end # https://github.com/JuliaLang/julia/blob/64378db18b512677fc6d3b012e6d1f02077af191/src/cgutils.cpp#L823 # returns if all unboxed -function for_each_uniontype_small(f, ty, counter = Ref(0)) +function for_each_uniontype_small(@nospecialize(f), @nospecialize(ty::Type), counter::Base.RefValue{Int} = Ref(0)) if counter[] > 127 return false end @@ -5306,9 +5417,9 @@ function for_each_uniontype_small(f, ty, counter = Ref(0)) end # From https://github.com/JuliaLang/julia/blob/038d31463f0ef744c8308bdbe87339b9c3f0b890/src/cgutils.cpp#L3108 -function union_alloca_type(UT) +function union_alloca_type(@nospecialize(UT::Type)) nbytes = 0 - function inner(jlrettype) + function inner(@nospecialize(jlrettype::Type)) if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) nbytes = max(nbytes, sizeof(jlrettype)) end @@ -5318,7 +5429,7 @@ function union_alloca_type(UT) end # From https://github.com/JuliaLang/julia/blob/e6bf81f39a202eedc7bd4f310c1ab60b5b86c251/src/codegen.cpp#L6447 -function is_sret(jlrettype) +function is_sret(@nospecialize(jlrettype::Type)) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false @@ -5341,7 +5452,7 @@ function is_sret(jlrettype) end return false end -function is_sret_union(jlrettype) +function is_sret_union(@nospecialize(jlrettype::Type)) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false @@ -5361,7 +5472,7 @@ end # https://github.com/JuliaLang/julia/blob/0a696a3842750fcedca8832bc0aabe9096c7658f/src/codegen.cpp#L6812 function get_return_info( - jlrettype, + @nospecialize(jlrettype::Type), )::Tuple{Union{Nothing,Type},Union{Nothing,Type},Union{Nothing,Type}} sret = nothing returnRoots 
= nothing @@ -5412,13 +5523,13 @@ end # Modified from GPUCompiler/src/irgen.jl:365 lower_byval function lower_convention( - functy::Type, + @nospecialize(functy::Type), mod::LLVM.Module, entry_f::LLVM.Function, - actualRetType::Type, - RetActivity, - TT, - run_enzyme, + @nospecialize(actualRetType::Type), + @nospecialize(RetActivity::Type), + @nospecialize(TT::Union{Type, Nothing}), + run_enzyme::Bool, ) entry_ft = LLVM.function_type(entry_f) @@ -5787,7 +5898,7 @@ function lower_convention( T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - function inner(jlrettype) + function inner(@nospecialize(jlrettype::Type)) BB = BasicBlock(wrapper_f, "box_union") position!(builder, BB) @@ -6268,7 +6379,7 @@ function GPUCompiler.codegen( instruction_combining!(pm) LLVM.run!(pm, mod) end - toremove = [] + toremove = String[] for f in functions(mod) if !any( map( @@ -6558,7 +6669,7 @@ end name = meth.name jlmod = meth.module - function handleCustom(llvmfn, name, attrs = [], setlink = true, noinl = true) + function handleCustom(llvmfn::LLVM.Function, name::String, attrs::Vector{LLVM.Attribute} = LLVM.Attribute[], setlink::Bool = true, noinl::Bool = true) attributes = function_attributes(llvmfn) custom[k_name] = linkage(llvmfn) if setlink @@ -6580,7 +6691,7 @@ end handleCustom( llvmfn, "enzyme_custom", - [StringAttribute("enzyme_preserve_primal", "*")], + LLVM.Attribute[StringAttribute("enzyme_preserve_primal", "*")], ) continue end @@ -6594,7 +6705,7 @@ end handleCustom( llvmfn, "jl_inactive_inout", - [ + LLVM.Attribute[ StringAttribute("enzyme_inactive"), EnumAttribute("readnone"), EnumAttribute("speculatable"), @@ -6605,7 +6716,7 @@ end handleCustom( llvmfn, "jl_inactive_inout", - [ + LLVM.Attribute[ StringAttribute("enzyme_inactive"), EnumAttribute("memory", NoEffects.data), EnumAttribute("speculatable"), @@ -6620,7 +6731,7 @@ end handleCustom( llvmfn, "jl_to_tuple_type", - [ + 
LLVM.Attribute[ EnumAttribute("readonly"), EnumAttribute("inaccessiblememonly", 0), EnumAttribute("speculatable", 0), @@ -6632,7 +6743,7 @@ end handleCustom( llvmfn, "jl_to_tuple_type", - [ + LLVM.Attribute[ EnumAttribute( "memory", MemoryEffect( @@ -6655,7 +6766,7 @@ end handleCustom( llvmfn, "jl_mightalias", - [ + LLVM.Attribute[ EnumAttribute("readonly"), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), @@ -6670,7 +6781,7 @@ end handleCustom( llvmfn, "jl_mightalias", - [ + LLVM.Attribute[ EnumAttribute("memory", ReadOnlyEffects.data), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), @@ -6690,7 +6801,7 @@ end handleCustom( llvmfn, name, - [ + LLVM.Attribute[ EnumAttribute("readonly"), EnumAttribute("inaccessiblememonly"), EnumAttribute("speculatable"), @@ -6703,7 +6814,7 @@ end handleCustom( llvmfn, name, - [ + LLVM.Attribute[ EnumAttribute( "memory", MemoryEffect( @@ -6731,7 +6842,7 @@ end handleCustom( llvmfn, "enz_noop", - [ + LLVM.Attribute[ StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur"), @@ -6741,7 +6852,7 @@ end handleCustom( llvmfn, "enz_noop", - [ + LLVM.Attribute[ StringAttribute("enzyme_inactive"), EnumAttribute("memory", ReadOnlyEffects.data), StringAttribute("enzyme_ta_norecur"), @@ -6759,7 +6870,7 @@ end handleCustom( llvmfn, "enz_noop", - [ + LLVM.Attribute[ StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), @@ -6777,7 +6888,7 @@ end handleCustom( llvmfn, "enz_noop", - [ + LLVM.Attribute[ StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), @@ -6822,7 +6933,7 @@ end handleCustom( llvmfn, "base_match", - [ + LLVM.Attribute[ StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), @@ -6865,12 +6976,12 @@ end if func == typeof(Base.enq_work) && length(sparam_vals) == 1 && 
first(sparam_vals) <: Task - handleCustom(llvmfn, "jl_enq_work", [StringAttribute("enzyme_ta_norecur")]) + handleCustom(llvmfn, "jl_enq_work", LLVM.Attribute[StringAttribute("enzyme_ta_norecur")]) continue end if func == typeof(Base.wait) || func == typeof(Base._wait) if length(sparam_vals) == 1 && first(sparam_vals) <: Task - handleCustom(llvmfn, "jl_wait", [StringAttribute("enzyme_ta_norecur")]) + handleCustom(llvmfn, "jl_wait", LLVM.Attribute[StringAttribute("enzyme_ta_norecur")]) end continue end @@ -6914,9 +7025,9 @@ end name = T == Float32 ? name * "f" : name attrs = if LLVM.version().major <= 15 - [LLVM.EnumAttribute("readnone"), StringAttribute("enzyme_shouldrecompute")] + LLVM.Attribute[LLVM.EnumAttribute("readnone"), StringAttribute("enzyme_shouldrecompute")] else - [EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute")] + LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute")] end handleCustom(llvmfn, name, attrs) end @@ -7351,7 +7462,7 @@ end loweredArgs, boxedArgs, ) - toremove = [] + toremove = String[] # Inline the wrapper for f in functions(mod) for b in blocks(f) @@ -7817,7 +7928,7 @@ end ) -function jl_set_typeof(v::Ptr{Cvoid}, T) +function jl_set_typeof(v::Ptr{Cvoid}, @nospecialize(T::Type)) tag = reinterpret(Ptr{Any}, reinterpret(UInt, v) - 8) Base.unsafe_store!(tag, T) # set tag return nothing @@ -8023,7 +8134,7 @@ end error("Return type `$rrt` not marked Const, but is ghost or const type.") end - sret_types = [] # Julia types of all returned variables + sret_types = Type[] # Julia types of all returned variables # By ref values we create and need to preserve ccexprs = Union{Expr,Symbol}[] # The expressions passed to the `llvmcall` @@ -8421,7 +8532,7 @@ end # JIT ## -function _link(job, (mod, adjoint_name, primal_name, TapeType)) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, 
Nothing}), @nospecialize(TapeType)) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), @@ -8443,7 +8554,7 @@ function _link(job, (mod, adjoint_name, primal_name, TapeType)) ), ) end - if primal_name === nothing + if primal_name isa Nothing primal_ptr = C_NULL else primal_addr = JIT.lookup(primal_name) @@ -8507,7 +8618,7 @@ const cache_lock = ReentrantLock() obj = get(cache, key, nothing) if obj === nothing asm = _thunk(job) - obj = _link(job, asm) + obj = _link(job, asm...) cache[key] = obj end obj @@ -8526,34 +8637,20 @@ end @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated @inline function thunkbase( - ctx, mi::Core.MethodInstance, - ::Val{World}, - ::Type{FA}, - ::Type{A}, - tt::Type{TT}, - ::Val{Mode}, - ::Val{width}, - ::Val{ModifiedBetween}, - ::Val{ReturnPrimal}, - ::Val{ShadowInit}, - ::Type{ABI}, - ::Val{ErrIfFuncWritten}, - ::Val{RuntimeActivity}, -) where { - FA<:Annotation, - A<:Annotation, - TT, - Mode, - ModifiedBetween, - width, - ReturnPrimal, - ShadowInit, - World, - ABI, - ErrIfFuncWritten, - RuntimeActivity, -} + World::Union{UInt, Nothing}, + @nospecialize(FA::Type{<:Annotation}), + @nospecialize(A::Type{<:Annotation}), + @nospecialize(TT::Type), + Mode::API.CDerivativeMode, + width::Int, + @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), + ReturnPrimal::Bool, + ShadowInit::Bool, + @nospecialize(ABI::Type), + ErrIfFuncWritten::Bool, + RuntimeActivity::Bool, +) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( Tuple{FA,TT.parameters...}, @@ -8579,11 +8676,6 @@ end interp = GPUCompiler.get_interpreter(tmp_job) # TODO check compile return here, early - # rrt = Core.Compiler.return_type(f, primal.tt) # nothing - rrt = something( - Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), - Any, - ) rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype run_enzyme = true @@ -8722,198 +8814,300 @@ end activate(ctx) try 
return thunkbase( - ctx, mi, - Val(nothing), + nothing, FA, A, TT, - Val(Mode), - Val(width), - Val(ModifiedBetween), - Val(ReturnPrimal), - Val(ShadowInit), + Mode, + width, + ModifiedBetween, + ReturnPrimal, + ShadowInit, ABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=World=# + ErrIfFuncWritten, + RuntimeActivity, + ) finally deactivate(ctx) dispose(ts_ctx) end end -@inline @generated function thunk( - ::Val{World}, - ::Type{FA}, - ::Type{A}, - tt::Type{TT}, - ::Val{Mode}, - ::Val{width}, - ::Val{ModifiedBetween}, - ::Val{ReturnPrimal}, - ::Val{ShadowInit}, - ::Type{ABI}, - ::Val{ErrIfFuncWritten}, - ::Val{RuntimeActivity}, -) where { - FA<:Annotation, - A<:Annotation, - TT, - Mode, - ModifiedBetween, - width, - ReturnPrimal, - ShadowInit, - World, - ABI, - ErrIfFuncWritten, - RuntimeActivity, -} - mi = fspec(eltype(FA), TT, World) +function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, @nospecialize(self), @nospecialize(fakeworld), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(abi::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type)) + @nospecialize + + parmnames = (:fakeworld, :fa, :a, :tt, :mode, :width, :modifiedbetween, :returnprimal, :shadowinit, :abi, :erriffuncwritten, :runtimeactivity) + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, parmnames...), Core.svec()) + + ft = eltype(FA) + primal_tt = Tuple{map(eltype, TT.parameters)...} + # look up the method match + method_error = 
:(throw(MethodError($ft, $primal_tt, $world))) + sig = Tuple{ft, primal_tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + match = ccall(:jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, #=mt=# nothing, world, min_world, max_world) + match === nothing && return stub(world, source, method_error) + + # look up the method and code instance + mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams) + + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + # new_ci.min_world = min_world[] + new_ci.min_world = world + new_ci.max_world = max_world[] + new_ci.edges = Core.MethodInstance[mi] + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. + ts_ctx = JuliaContext() ctx = context(ts_ctx) activate(ctx) res = try thunkbase( - ctx, mi, - Val(World), + world, FA, A, TT, - Val(Mode), - Val(width), - Val(ModifiedBetween), - Val(ReturnPrimal), - Val(ShadowInit), + Mode, + Width, + ModifiedBetween, + ReturnPrimal, + ShadowInit, ABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), + ErrIfFuncWritten, + RuntimeActivity, ) finally deactivate(ctx) dispose(ts_ctx) end - return quote - Base.@_inline_meta - return $(res) + + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), parmnames...] 
+ new_ci.slotflags = UInt8[0x00 for i = 1:length(new_ci.slotnames)] + + # return the codegen world age + push!(new_ci.code, Core.Compiler.ReturnNode(res)) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below end -end + new_ci.ssavaluetypes += 1 -import GPUCompiler: deferred_codegen_jobs + # NOTE: we keep the first entry of the original linetable, and use it for location info + # on the call to check_cache. we can't not have a codeloc (using 0 causes + # corruption of the back trace), and reusing the target function's info + # has as advantage that we see the name of the kernel in the backtraces. -@generated function deferred_codegen( - ::Val{World}, - ::Type{FA}, - ::Val{TT}, - ::Val{A}, - ::Val{Mode}, - ::Val{width}, - ::Val{ModifiedBetween}, - ::Val{ReturnPrimal}, - ::Val{ShadowInit}, - ::Type{ExpectedTapeType}, - ::Val{ErrIfFuncWritten}, - ::Val{RuntimeActivity}, + return new_ci +end + +@eval @inline function thunk( + fakeworld::Val{0}, + fa::Type{FA}, + a::Type{A}, + tt::Type{TT}, + mode::Val{Mode}, + width::Val{Width}, + modifiedbetween::Val{ModifiedBetween}, + returnprimal::Val{ReturnPrimal}, + shadowinit::Val{ShadowInit}, + abi::Type{ABI}, + erriffuncwritten::Val{ErrIfFuncWritten}, + runtimeactivity::Val{RuntimeActivity}, ) where { - World, FA<:Annotation, + A<:Annotation, TT, - A, Mode, - width, + Width, ModifiedBetween, ReturnPrimal, ShadowInit, - ExpectedTapeType, + ABI, ErrIfFuncWritten, RuntimeActivity, } - JuliaContext() do ctx - Base.@_inline_meta - mi = fspec(eltype(FA), TT, World) - target = EnzymeTarget() + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, thunk_generator)) +end - rt2 = if A isa UnionAll - params = EnzymeCompilerParams( - Tuple{FA,TT.parameters...}, - Mode, - width, - remove_innerty(A), - true, - true, - ModifiedBetween, - ReturnPrimal, - 
ShadowInit, - ExpectedTapeType, - FFIABI, - ErrIfFuncWritten, - RuntimeActivity, - ) #=abiwrap=# - tmp_job = Compiler.CompilerJob( - mi, - CompilerConfig(target, params; kernel = false), - World, - ) +import GPUCompiler: deferred_codegen_jobs - interp = GPUCompiler.get_interpreter(tmp_job) +function deferred_id_generator(world::UInt, source::LineNumberNode, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ExpectedTapeType::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, @nospecialize(self), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(expectedtapetype::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type)) + @nospecialize + + parmnames = (:fa, :a, :tt, :mode, :width, :modifiedbetween, :returnprimal, :shadowinit, :expectedtapetype, :erriffuncwritten, :runtimeactivity) + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, parmnames...), Core.svec()) + + ft = eltype(FA) + primal_tt = Tuple{map(eltype, TT.parameters)...} + # look up the method match + method_error = :(throw(MethodError($ft, $primal_tt, $world))) + sig = Tuple{ft, primal_tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + match = ccall(:jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, #=mt=# nothing, world, min_world, max_world) + match === nothing && return stub(world, source, method_error) + + # look up the method and code instance + mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams) + + ci = 
Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + # new_ci.min_world = min_world[] + new_ci.min_world = world + new_ci.max_world = max_world[] + new_ci.edges = Core.MethodInstance[mi] + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. + + target = EnzymeTarget() - rrt = something( - Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), - Any, - ) + rt2 = if A isa UnionAll + rrt = primal_return_type_world(Mode == API.DEM_ForwardMode ? 
Forward : Reverse, world, mi) - # Don't error here but default to nothing return since in cuda context we don't use the device overrides - if rrt == Union{} - rrt = Nothing - end + # Don't error here but default to nothing return since in cuda context we don't use the device overrides + if rrt == Union{} + rrt = Nothing + end - if !(A <: Const) && guaranteed_const_nongen(rrt, World) - estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" - return quote - error($estr) - end + if !(A <: Const) && guaranteed_const_nongen(rrt, world) + estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" + return quote + error($estr) end - A{rrt} - else - @assert A isa DataType - A end + A{rrt} + else + @assert A isa DataType + A + end - params = EnzymeCompilerParams( - Tuple{FA,TT.parameters...}, - Mode, - width, - rt2, - true, - true, - ModifiedBetween, - ReturnPrimal, - ShadowInit, - ExpectedTapeType, - FFIABI, - ErrIfFuncWritten, - RuntimeActivity, - ) #=abiwrap=# - job = - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) + params = EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + Width, + rt2, + true, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + FFIABI, + ErrIfFuncWritten, + RuntimeActivity, + ) #=abiwrap=# + job = + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), world) - addr = get_trampoline(job) - id = Base.reinterpret(Int, pointer(addr)) - deferred_codegen_jobs[id] = job + addr = get_trampoline(job) + id = Base.reinterpret(Int, pointer(addr)) + deferred_codegen_jobs[id] = job - quote - Base.@_inline_meta - ccall( - "extern deferred_codegen", - llvmcall, - Ptr{Cvoid}, - (Ptr{Cvoid},), - $(reinterpret(Ptr{Cvoid}, id)), - ) - end + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), parmnames...] 
+ new_ci.slotflags = UInt8[0x00 for i = 1:length(new_ci.slotnames)] + + # return the codegen world age + push!(new_ci.code, Core.Compiler.ReturnNode(reinterpret(Ptr{Cvoid}, id))) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below end + new_ci.ssavaluetypes += 1 + + # NOTE: we keep the first entry of the original linetable, and use it for location info + # on the call to check_cache. we can't not have a codeloc (using 0 causes + # corruption of the back trace), and reusing the target function's info + # has as advantage that we see the name of the kernel in the backtraces. + + return new_ci +end + +@eval @inline function deferred_id_codegen( + fa::Type{FA}, + a::Type{A}, + tt::Type{TT}, + mode::Val{Mode}, + width::Val{Width}, + modifiedbetween::Val{ModifiedBetween}, + returnprimal::Val{ReturnPrimal}, + shadowinit::Val{ShadowInit}, + expectedtapetype::Type{ExpectedTapeType}, + erriffuncwritten::Val{ErrIfFuncWritten}, + runtimeactivity::Val{RuntimeActivity}, +) where { + FA<:Annotation, + A<:Annotation, + TT, + Mode, + Width, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + ErrIfFuncWritten, + RuntimeActivity, +} + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, deferred_id_generator)) +end + +@inline function deferred_codegen( + @nospecialize(fa::Type), + @nospecialize(a::Type), + @nospecialize(tt::Type), + @nospecialize(mode::Val), + @nospecialize(width::Val), + @nospecialize(modifiedbetween::Val), + @nospecialize(returnprimal::Val), + @nospecialize(shadowinit::Val), + @nospecialize(expectedtapetype::Type), + @nospecialize(erriffuncwritten::Val), + @nospecialize(runtimeactivity::Val) +) + id = deferred_id_codegen(fa, a, tt, mode, width, modifiedbetween, returnprimal, shadowinit, expectedtapetype, erriffuncwritten, runtimeactivity) + ccall("extern 
deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), id) end include("compiler/reflection.jl") @@ -8981,7 +9175,7 @@ include("compiler/reflection.jl") emit_box_int32!(builder, len) end - tag = emit_apply_type!(builder, NTuple, (boxed_count, unsafe_to_llvm(builder, T))) + tag = emit_apply_type!(builder, NTuple, LLVM.Value[boxed_count, unsafe_to_llvm(builder, T)]) fullsize = nuwmul!(builder, len, LLVM.ConstantInt(sizeof(Int))) obj = emit_allocobj!(builder, tag, fullsize, needs_dynamic_size_workaround) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 22761c2d1a..7648236ffb 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -96,31 +96,31 @@ EnzymeInterpreter( handler = nothing ) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) -Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params -Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params -get_inference_world(interp::EnzymeInterpreter) = interp.world -Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cache +Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params +Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params +get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world +Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache @static if HAS_INTEGRATED_CACHE - Core.Compiler.cache_owner(interp::EnzymeInterpreter) = interp.token + Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token else - Core.Compiler.code_cache(interp::EnzymeInterpreter) = + Core.Compiler.code_cache(@nospecialize(interp::EnzymeInterpreter)) = WorldView(interp.code_cache, 
interp.world) end # No need to do any locking since we're not putting our results into the runtime cache -Core.Compiler.lock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing -Core.Compiler.unlock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing +Core.Compiler.lock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing +Core.Compiler.unlock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing -Core.Compiler.may_optimize(::EnzymeInterpreter) = true -Core.Compiler.may_compress(::EnzymeInterpreter) = true +Core.Compiler.may_optimize(@nospecialize(::EnzymeInterpreter)) = true +Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true # From @aviatesk: # `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded, # but as far as I understand Enzyme wants "always inlining, except special cased functions", # so I guess we really don't want to discard sources? -Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false -Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false +Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false +Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false -Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = +Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) = Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) function is_alwaysinline_func(@nospecialize(TT)) @@ -194,7 +194,7 @@ Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = using Core.Compiler: ArgInfo, StmtInfo, AbsIntState function Core.Compiler.abstract_call_gf_by_type( - interp::EnzymeInterpreter, + @nospecialize(interp::EnzymeInterpreter), @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, @@ -214,21 +214,27 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = ret.info method_table = 
Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) + caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult + callinfo.results + else + nothing + end + if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) + elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller) callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller) callinfo = NoInlineCallInfo(callinfo, atype, :frule) end end if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller) callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end @@ -243,7 +249,7 @@ end let # overload `inlining_policy` @static if VERSION ≥ v"1.11.0-DEV.879" sigs_ex = :( - interp::EnzymeInterpreter, + @nospecialize(interp::EnzymeInterpreter), @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt32, @@ -256,7 +262,7 @@ let # overload `inlining_policy` ) else sigs_ex = :( - interp::EnzymeInterpreter, + @nospecialize(interp::EnzymeInterpreter), @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt8, @@ -753,7 +759,7 @@ end end function abstract_call_known( - interp::EnzymeInterpreter, + @nospecialize(interp::EnzymeInterpreter), @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 8c42ee2b55..66313d3853 100644 --- a/src/compiler/optimize.jl +++ 
b/src/compiler/optimize.jl @@ -18,20 +18,20 @@ end const RunAttributor = Ref(true) function pipeline_options(; - lower_intrinsics = true, - dump_native = false, - external_use = false, - llvm_only = false, - always_inline = true, - enable_early_simplifications = true, - enable_early_optimizations = true, - enable_scalar_optimizations = true, - enable_loop_optimizations = true, - enable_vector_pipeline = true, - remove_ni = true, - cleanup = true, - Size = 0, - Speedup = 3, + lower_intrinsics::Bool = true, + dump_native::Bool = false, + external_use::Bool = false, + llvm_only::Bool = false, + always_inline::Bool = true, + enable_early_simplifications::Bool = true, + enable_early_optimizations::Bool = true, + enable_scalar_optimizations::Bool = true, + enable_loop_optimizations::Bool = true, + enable_vector_pipeline::Bool = true, + remove_ni::Bool = true, + cleanup::Bool = true, + Size::Cint = 0, + Speedup::Cint = 3, ) return PipelineConfig( Speedup, @@ -51,7 +51,7 @@ function pipeline_options(; ) end -function run_jl_pipeline(pm, tm; kwargs...) +function run_jl_pipeline(pm::ModulePassManager, tm::LLVM.TargetMachine; kwargs...) 
config = Ref(pipeline_options(; kwargs...)) function jl_pipeline(m) @dispose pb = NewPMPassBuilder() begin @@ -75,12 +75,12 @@ else end @static if VERSION < v"1.11-" - function gc_invariant_verifier_tm!(pm, tm, cond) + function gc_invariant_verifier_tm!(pm::ModulePassManager, tm::LLVM.TargetMachine, cond::Bool) gc_invariant_verifier!(pm, cond) end else - function gc_invariant_verifier_tm!(pm, tm, cond) - function gc_invariant_verifier(mod) + function gc_invariant_verifier_tm!(pm::ModulePassManager, tm::LLVM.TargetMachine, cond::Bool) + function gc_invariant_verifier(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -96,12 +96,12 @@ else end @static if VERSION < v"1.11-" - function propagate_julia_addrsp_tm!(pm, tm) + function propagate_julia_addrsp_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) propagate_julia_addrsp!(pm) end else - function propagate_julia_addrsp_tm!(pm, tm) - function prop_julia_addr(mod) + function propagate_julia_addrsp_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function prop_julia_addr(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -117,12 +117,12 @@ else end @static if VERSION < v"1.11-" - function alloc_opt_tm!(pm, tm) + function alloc_opt_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) alloc_opt!(pm) end else - function alloc_opt_tm!(pm, tm) - function alloc_opt(mod) + function alloc_opt_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function alloc_opt(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -138,12 +138,12 @@ else end @static if VERSION < v"1.11-" - function remove_ni_tm!(pm, tm) + function remove_ni_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) remove_ni!(pm) end else - function 
remove_ni_tm!(pm, tm) - function remove_ni(mod) + function remove_ni_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function remove_ni(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, RemoveNIPass()) @@ -157,12 +157,12 @@ else end @static if VERSION < v"1.11-" - function julia_licm_tm!(pm, tm) + function julia_licm_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) julia_licm!(pm) end else - function julia_licm_tm!(pm, tm) - function julia_licm(mod) + function julia_licm_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function julia_licm(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -181,12 +181,12 @@ else end @static if VERSION < v"1.11-" - function lower_simdloop_tm!(pm, tm) + function lower_simdloop_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) lower_simdloop!(pm) end else - function lower_simdloop_tm!(pm, tm) - function lower_simdloop(mod) + function lower_simdloop_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function lower_simdloop(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -205,7 +205,7 @@ else end -function loop_optimizations_tm!(pm, tm) +function loop_optimizations_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) @static if true || VERSION < v"1.11-" lower_simdloop_tm!(pm, tm) licm!(pm) @@ -235,7 +235,7 @@ function loop_optimizations_tm!(pm, tm) end -function more_loop_optimizations_tm!(pm, tm) +function more_loop_optimizations_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) @static if true || VERSION < v"1.11-" loop_rotate!(pm) # moving IndVarSimplify here prevented removing the loop in perf_sumcartesian(10:-1:1) @@ -287,12 +287,12 @@ function more_loop_optimizations_tm!(pm, tm) end @static if VERSION < v"1.11-" - function 
demote_float16_tm!(pm, tm) + function demote_float16_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) demote_float16!(pm) end else - function demote_float16_tm!(pm, tm) - function demote_float16(mod) + function demote_float16_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function demote_float16(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -308,12 +308,12 @@ else end @static if VERSION < v"1.11-" - function lower_exc_handlers_tm!(pm, tm) + function lower_exc_handlers_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) lower_exc_handlers!(pm) end else - function lower_exc_handlers_tm!(pm, tm) - function lower_exc_handlers(mod) + function lower_exc_handlers_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function lower_exc_handlers(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -329,12 +329,12 @@ else end @static if VERSION < v"1.11-" - function lower_ptls_tm!(pm, tm, dump_native) + function lower_ptls_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine, dump_native::Bool) lower_ptls!(pm, dump_native) end else - function lower_ptls_tm!(pm, tm, dump_native) - function lower_ptls(mod) + function lower_ptls_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine, dump_native::Bool) + function lower_ptls(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, LowerPTLSPass()) @@ -348,13 +348,13 @@ else end @static if VERSION < v"1.11-" - function combine_mul_add_tm!(pm, tm) + function combine_mul_add_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) combine_mul_add!(pm) end else - function combine_mul_add_tm!(pm, tm) + function combine_mul_add_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) @static if VERSION < v"1.12.0-DEV.1390" - function combine_mul_add(mod) + function 
combine_mul_add(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -371,12 +371,12 @@ end end @static if VERSION < v"1.11-" - function late_lower_gc_frame_tm!(pm, tm) + function late_lower_gc_frame_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) late_lower_gc_frame!(pm) end else - function late_lower_gc_frame_tm!(pm, tm) - function late_lower_gc_frame(mod) + function late_lower_gc_frame_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function late_lower_gc_frame(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -392,12 +392,12 @@ else end @static if VERSION < v"1.11-" - function final_lower_gc_tm!(pm, tm) + function final_lower_gc_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) final_lower_gc!(pm) end else - function final_lower_gc_tm!(pm, tm) - function final_lower_gc(mod) + function final_lower_gc_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) + function final_lower_gc(mod::LLVM.Module) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm @@ -413,7 +413,7 @@ else end @static if VERSION < v"1.11-" - function cpu_features_tm!(pm, tm) + function cpu_features_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) @static if isdefined(LLVM.Interop, :cpu_features!) 
LLVM.Interop.cpu_features!(pm) else @@ -423,7 +423,7 @@ end end end else - function cpu_features_tm!(pm, tm) + function cpu_features_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) function cpu_features(mod) @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm @@ -437,7 +437,7 @@ else end end -function addNA(inst, node::LLVM.Metadata, MD) +function addNA(@nospecialize(inst::LLVM.Instruction), @nospecialize(node::LLVM.Metadata), MD::LLVM.MDKind) md = metadata(inst) next = nothing if haskey(md, MD) @@ -475,15 +475,6 @@ function addr13NoAlias(mod::LLVM.Module) end end -function source_elem(v) - @static if LLVM.version() >= v"15" - LLVM.LLVMType(LLVM.API.LLVMGetGEPSourceElementType(v)) - else - eltype(value_type(operands(v)[1])) - end -end - - ## given code like # % a = alloca # ... @@ -788,7 +779,7 @@ function nodecayed_phis!(mod::LLVM.Module) v0 = v - @inline function getparent(v, offset, hasload) + @inline function getparent(@nospecialize(v::LLVM.Value), @nospecialize(offset::LLVM.Value), hasload::Bool) if addr == 11 && addrspace(value_type(v)) == 10 return v, offset, hasload end @@ -1099,7 +1090,7 @@ function nodecayed_phis!(mod::LLVM.Module) nphi = nextvs[inst] - function ogbc(x) + function ogbc(@nospecialize(x::LLVM.Value)) while isa(x, LLVM.BitCastInst) x = operands(x)[1] end @@ -1359,7 +1350,7 @@ function fix_decayaddr!(mod::LLVM.Module) position!(nb, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(st))) ld = load!(nb, elt, temp) si = store!(nb, ld, operands(inst)[1]) - julia_post_cache_store(si.ref, nb.ref, C_NULL) + julia_post_cache_store(si.ref, nb.ref, reinterpret(Ptr{UInt64}, C_NULL)) end end @@ -1401,7 +1392,7 @@ function pre_attr!(mod::LLVM.Module) return nothing end -function jl_inst_simplify!(PM) +function jl_inst_simplify!(PM::LLVM.ModulePassManager) ccall( (:LLVMAddJLInstSimplifyPass, API.libEnzyme), Cvoid, @@ -1412,7 +1403,7 @@ end function post_attr!(mod::LLVM.Module) end -function prop_global!(g) +function 
prop_global!(g::LLVM.GlobalVariable) newfns = String[] changed = false todo = Tuple{Vector{Cuint},LLVM.Value}[] @@ -1484,7 +1475,7 @@ function prop_global!(g) end # From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 -function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly = false)::Bool +function mayWriteToMemory(@nospecialize(inst::LLVM.Instruction); err_is_readonly::Bool = false)::Bool # we will ignore fense here if isa(inst, LLVM.StoreInst) return true @@ -1991,7 +1982,7 @@ function propagate_returned!(mod::LLVM.Module) end end -function delete_writes_into_removed_args(fn::LLVM.Function, toremove, keepret::Bool) +function delete_writes_into_removed_args(fn::LLVM.Function, toremove::Vector{Int64}, keepret::Bool) args = collect(parameters(fn)) for tr in toremove tr = tr + 1 @@ -2145,7 +2136,7 @@ function detect_writeonly!(mod::LLVM.Module) return nothing end -function validate_return_roots!(mod) +function validate_return_roots!(mod::LLVM.Module) for f in functions(mod) srets = [] enzyme_srets = Int[] @@ -2315,7 +2306,7 @@ function validate_return_roots!(mod) end end -function checkNoAssumeFalse(mod, shouldshow = false) +function checkNoAssumeFalse(mod::LLVM.Module, shouldshow::Bool = false) for f in functions(mod) for bb in blocks(f), inst in instructions(bb) if !isa(inst, LLVM.CallInst) @@ -2358,7 +2349,7 @@ end cse!(pm) = LLVM.API.LLVMAddEarlyCSEPass(pm) -function removeDeadArgs!(mod::LLVM.Module, tm) +function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine) # We need to run globalopt first. This is because remove dead args will otherwise # take internal functions and replace their args with undef. 
Then on LLVM up to # and including 12 (but fixed 13+), Attributor will incorrectly change functions that @@ -2375,13 +2366,13 @@ function removeDeadArgs!(mod::LLVM.Module, tm) mod, "llvm.enzymefakeuse", funcT, - [EnumAttribute("readnone"), EnumAttribute("nofree")], + LLVM.Attribute[EnumAttribute("readnone"), EnumAttribute("nofree")], ) rfunc, _ = get_function!( mod, "llvm.enzymefakeread", funcT, - [ + LLVM.Attribute[ EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly"), @@ -2391,7 +2382,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) mod, "llvm.enzyme.sret_use", funcT, - [ + LLVM.Attribute[ EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly"), @@ -2402,19 +2393,19 @@ function removeDeadArgs!(mod::LLVM.Module, tm) mod, "llvm.enzymefakeuse", funcT, - [EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")], + LLVM.Attribute[EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")], ) rfunc, _ = get_function!( mod, "llvm.enzymefakeread", funcT, - [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + LLVM.Attribute[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], ) sfunc, _ = get_function!( mod, "llvm.enzyme.sret_use", funcT, - [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + LLVM.Attribute[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], ) end @@ -2563,7 +2554,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) eraseInst(mod, func) end -function optimize!(mod::LLVM.Module, tm) +function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine) addr13NoAlias(mod) # everying except unroll, slpvec, loop-vec # then finish Julia GC @@ -2696,13 +2687,13 @@ function optimize!(mod::LLVM.Module, tm) end # https://github.com/JuliaLang/julia/blob/2eb5da0e25756c33d1845348836a0a92984861ac/src/aotcompile.cpp#L603 -function addTargetPasses!(pm, tm, trip) +function 
addTargetPasses!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine, trip::String) add_library_info!(pm, trip) add_transform_info!(pm, tm) end # https://github.com/JuliaLang/julia/blob/2eb5da0e25756c33d1845348836a0a92984861ac/src/aotcompile.cpp#L620 -function addOptimizationPasses!(pm, tm) +function addOptimizationPasses!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) constant_merge!(pm) @@ -2787,7 +2778,7 @@ function addOptimizationPasses!(pm, tm) aggressive_dce!(pm) end -function addMachinePasses!(pm, tm) +function addMachinePasses!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine) combine_mul_add_tm!(pm, tm) # TODO: createDivRemPairs[] @@ -2795,7 +2786,7 @@ function addMachinePasses!(pm, tm) gvn!(pm) end -function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics = true) +function addJuliaLegalizationPasses!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine, lower_intrinsics::Bool = true) if lower_intrinsics # LowerPTLS removes an indirect call. 
As a result, it is likely to trigger # LLVM's devirtualization heuristics, which would result in the entire @@ -2830,7 +2821,7 @@ function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics = true) end end -function post_optimze!(mod, tm, machine = true) +function post_optimze!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true) addr13NoAlias(mod) removeDeadArgs!(mod, tm) for f in collect(functions(mod)) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 304372951c..40b1293d8d 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -18,13 +18,14 @@ function get_job( ) tt = Tuple{map(eltype, types.parameters)...} - if world === nothing - world = codegen_world_age(Core.Typeof(func), tt) - end - primal = fspec(Core.Typeof(func), types, world) - rt = Core.Compiler.return_type(func, tt, world) + primal, rt = if world isa Nothing + fspec(Core.Typeof(func), types), Compiler.primal_return_type(mode == API.DEM_ForwardMode ? Forward : Reverse, Core.Typeof(func), tt) + else + fspec(Core.Typeof(func), types, world), Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? 
Forward : Reverse, world, Core.Typeof(func), tt) + end + rt = A{rt} target = Compiler.EnzymeTarget() if modifiedBetween === nothing @@ -46,11 +47,18 @@ function get_job( ErrIfFuncWritten, RuntimeActivity, ) - return Compiler.CompilerJob( - primal, - CompilerConfig(target, params; kernel = false), - world, - ) + if world isa Nothing + return Compiler.CompilerJob( + primal, + CompilerConfig(target, params; kernel = false), + ) + else + return Compiler.CompilerJob( + primal, + CompilerConfig(target, params; kernel = false), + world, + ) + end end function reflect( diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 09ff90a50a..e58e574dd2 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -275,9 +275,9 @@ end function get_function!( mod::LLVM.Module, - name::AbstractString, + name::String, FT::LLVM.FunctionType, - attrs = [], + attrs::Vector{LLVM.Attribute} = LLVM.Attribute[], ) if haskey(functions(mod), name) F = functions(mod)[name] @@ -294,13 +294,13 @@ function get_function!( return F, FT end -function get_function!(builderF, mod::LLVM.Module, name) +function get_function!(@nospecialize(builderF), mod::LLVM.Module, name::String) get_function!(mod, name, builderF()) end T_ppjlvalue() = LLVM.PointerType(LLVM.PointerType(LLVM.StructType(LLVMType[]))) -function declare_pgcstack!(mod) +function declare_pgcstack!(mod::LLVM.Module) get_function!( mod, "julia.get_pgcstack", @@ -308,7 +308,7 @@ function declare_pgcstack!(mod) ) end -function emit_pgcstack(B) +function emit_pgcstack(B::LLVM.IRBuilder) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -316,7 +316,7 @@ function emit_pgcstack(B) return call!(B, fty, func) end -function get_pgcstack(func) +function get_pgcstack(func::LLVM.Function) entry_bb = first(blocks(func)) pgcstack_func = declare_pgcstack!(LLVM.parent(func)) @@ -328,7 +328,7 @@ function get_pgcstack(func) return nothing end -function reinsert_gcmarker!(func, PB = nothing) +function 
reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing, LLVM.IRBuilder}) = nothing) for (i, v) in enumerate(parameters(func)) if any( map( @@ -341,7 +341,7 @@ function reinsert_gcmarker!(func, PB = nothing) end pgs = get_pgcstack(func) - if pgs === nothing + if pgs isa Nothing context(LLVM.parent(func)) B = IRBuilder() entry_bb = first(blocks(func)) @@ -355,13 +355,27 @@ function reinsert_gcmarker!(func, PB = nothing) entry_bb = first(blocks(func)) fst = first(instructions(entry_bb)) if fst != pgs - API.moveBefore(pgs, fst, PB === nothing ? C_NULL : PB.ref) + API.moveBefore(pgs, fst, PB isa Nothing ? C_NULL : PB.ref) end pgs end end -function eraseInst(bb, inst) +function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) + @static if isdefined(LLVM, Symbol("erase!")) + LLVM.erase!(inst) + else + unsafe_delete!(bb, inst) + end +end +function eraseInst(bb::LLVM.Module, inst::LLVM.Function) + @static if isdefined(LLVM, Symbol("erase!")) + LLVM.erase!(inst) + else + unsafe_delete!(bb, inst) + end +end +function eraseInst(bb::LLVM.Module, inst::LLVM.GlobalVariable) @static if isdefined(LLVM, Symbol("erase!")) LLVM.erase!(inst) else @@ -369,7 +383,7 @@ function eraseInst(bb, inst) end end -function unique_gcmarker!(func) +function unique_gcmarker!(func::LLVM.Function) entry_bb = first(blocks(func)) pgcstack_func = declare_pgcstack!(LLVM.parent(func)) @@ -393,7 +407,7 @@ end NamedTuple{ntuple(i -> Symbol(i), Val(length(U.parameters))),U} # recursively compute the eltype type indexed by idx[0], idx[1], ... 
-function recursive_eltype(val::LLVM.Value, idxs::Vector{Cuint}) +function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint}) ty = LLVM.value_type(val) for i in idxs if isa(ty, LLVM.ArrayType) @@ -409,10 +423,10 @@ end # Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float} # and that Bool -> i8, not i1 function calling_conv_fixup( - builder, - val::LLVM.Value, - tape::LLVM.LLVMType, - prev::LLVM.Value = LLVM.UndefValue(tape), + builder::LLVM.IRBuilder, + @nospecialize(val::LLVM.Value), + @nospecialize(tape::LLVM.LLVMType), + @nospecialize(prev::LLVM.Value) = LLVM.UndefValue(tape), lidxs::Vector{Cuint} = Cuint[], ridxs::Vector{Cuint} = Cuint[], emesg = nothing, diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 848b5734e7..6378b2e0a6 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -177,7 +177,7 @@ function restore_lookups(mod::LLVM.Module) end end -function check_ir(job, mod::LLVM.Module) +function check_ir(@nospecialize(job::CompilerJob), mod::LLVM.Module) errors = check_ir!(job, IRError[], mod) unique!(errors) if !isempty(errors) @@ -373,7 +373,7 @@ function rewrite_ccalls!(mod::LLVM.Module) end end -function check_ir!(job, errors, mod::LLVM.Module) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod::LLVM.Module) imported = Set(String[]) if haskey(functions(mod), "malloc") f = functions(mod)["malloc"] @@ -417,8 +417,8 @@ function check_ir!(job, errors, mod::LLVM.Module) return errors end -function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns) - calls = [] +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}) + calls = LLVM.CallInst[] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 mod = LLVM.parent(f) for bb in blocks(f), inst in collect(instructions(bb)) @@ -654,19 +654,19 
@@ const generic_method_offsets = Dict{String,Tuple{Int,Int}}(( "ijl_apply_generic" => (1, 2), )) -@inline function has_method(sig, world::UInt, mt::Union{Nothing,Core.MethodTable}) +@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}) return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing end -@inline function has_method(sig, world::UInt, mt::Core.Compiler.InternalMethodTable) +@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.InternalMethodTable) return has_method(sig, mt.world, nothing) end -@inline function has_method(sig, world::UInt, mt::Core.Compiler.OverlayMethodTable) +@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.OverlayMethodTable) return has_method(sig, mt.mt, mt.world) || has_method(sig, nothing, mt.world) end -@inline function is_inactive(tys, world::UInt, mt) +@inline function is_inactive(@nospecialize(tys::Union{Vector{Union{Type,Core.TypeofVararg}}, Core.SimpleVector}), world::UInt, @nospecialize(mt)) specTypes = Interpreter.simplify_kw(Tuple{tys...}) if has_method(Tuple{typeof(EnzymeRules.inactive),tys...}, world, mt) return true @@ -680,7 +680,7 @@ end import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic -function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}) world = job.world interp = GPUCompiler.get_interpreter(job) method_table = Core.Compiler.method_table(interp) @@ -887,7 +887,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if (isa(fname, LLVM.ConstantArray) || isa(fname, LLVM.ConstantDataArray)) && eltype(value_type(fname)) == LLVM.IntType(8) - fname = String(map((x) -> convert(UInt8, x), 
collect(fname)[1:(end-1)])) + fname = String(map(Base.Fix1(convert, UInt8), collect(fname)[1:(end-1)])) end if !isa(fname, String) || !isa(flib, String) @@ -1065,7 +1065,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) legal2, funclib, byref2 = abs_typeof(operands(inst)[funcoff+1]) if legal && (GT <: Vector || GT <: Tuple) if legal2 - tys = [funclib, Vararg{Any}] + tys = Union{Type, Core.TypeofVararg}[funclib, Vararg{Any}] if funclib == typeof(Core.apply_type) || is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") @@ -1139,7 +1139,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) # Add 1 to account for function being first arg legal, flibty, byref = abs_typeof(operands(inst)[offset+1]) if legal - tys = Type[flibty] + tys = Union{Type, Core.TypeofVararg}[flibty] for op in collect(operands(inst))[start+1:end-1] legal, typ, byref2 = abs_typeof(op, true) if !legal @@ -1241,7 +1241,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) legal, flibty, byref = abs_typeof(operands(inst)[offset]) if legal - tys = Type[flibty] + tys = Union{Type, Core.TypeofVararg}[flibty] for op in collect(operands(inst))[start:end-1] legal, typ, byref2 = abs_typeof(op, true) if !legal @@ -1309,7 +1309,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end -function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width) +function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off::Int64, world::UInt, width::Int) todo = Tuple{LLVM.Value,Tuple}[] for b in blocks(enzymefn) term = terminator(b) diff --git a/src/jlrt.jl b/src/jlrt.jl index 5a8cf33e0c..223697fbe0 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -43,7 +43,7 @@ function emit_allocobj!( if value_type(Size) != T_size_t # Fix Int32/Int64 issues on 32bit systems Size = trunc!(B, Size, T_size_t) end - return call!(B, alty, alloc_obj, [ptls, Size, tag]) + return call!(B, alty, 
alloc_obj, LLVM.Value[ptls, Size, tag]) end T_size_t = convert(LLVM.LLVMType, Int) @@ -56,7 +56,7 @@ function emit_allocobj!( alloc_obj, _ = get_function!(mod, "julia.gc_alloc_obj", alty) - return call!(B, alty, alloc_obj, [ct, Size, tag], name) + return call!(B, alty, alloc_obj, LLVM.Value[ct, Size, tag], name) end function emit_allocobj!(B::LLVM.IRBuilder, @nospecialize(T::DataType), name::String = "") curent_bb = position(B) @@ -91,7 +91,7 @@ function emit_pointerfromobjref!(B::LLVM.IRBuilder, @nospecialize(T::LLVM.Value) return call!(B, fty, func, [T]) end -declare_writebarrier!(mod) = +declare_writebarrier!(mod::LLVM.Module) = get_function!(mod, "julia.write_barrier") do T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -121,7 +121,7 @@ function emit_jl!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]) fn, _ = get_function!(mod, "jl_", FT) - call!(B, FT, fn, [val]) + call!(B, FT, fn, LLVM.Value[val]) end function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::LLVM.Value))::LLVM.Value @@ -133,7 +133,7 @@ function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospec ity = LLVM.IntType(8*sizeof(Int)) FT = LLVM.FunctionType(ity, [T_prjlvalue, T_prjlvalue]) fn, _ = get_function!(mod, "jl_isa", FT) - call!(B, FT, fn, [val, val]) + call!(B, FT, fn, LLVM.Value[val, ty]) end function emit_jl_isa!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(ty::Type))::LLVM.Value @@ -164,12 +164,13 @@ function emit_getfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nosp vararg = true, ), ) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) + nargs = LLVM.Value[inv, val, fld] + res = call!(B, FT, julia_call, nargs) return res end -function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, @nospecialize(fld::LLVM.Value))::LLVM.Value 
+function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(fld::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -181,12 +182,12 @@ function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, @nospecialize(fld::L gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_size_t]) inv, _ = get_function!(mod, "jl_get_nth_field_checked", gen_FT) - args = [val, fld] + args = LLVM.Value[val, fld] call!(B, gen_FT, inv, args) end -function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), fld::Integer)::LLVM.Value - emit_nthfield!(B, val, LLVM.ConstantInt(Int(fld))) +function emit_nthfield!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), fld::Int)::LLVM.Value + emit_nthfield!(B, val, LLVM.ConstantInt(fld)) end function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value @@ -198,7 +199,7 @@ function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM T_prjlvalue = LLVM.PointerType(T_jlvalue, 12) FT = LLVM.FunctionType(T_void, [T_prjlvalue]) fn, _ = get_function!(mod, "jl_throw", FT) - call!(B, FT, fn, [val]) + call!(B, FT, fn, LLVM.Value[val]) end function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value @@ -212,7 +213,7 @@ function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLV FT = LLVM.FunctionType(T_prjlvalue, [T_int32]) box_int32, _ = get_function!(mod, "ijl_box_int32", FT) - call!(B, FT, box_int32, [val]) + call!(B, FT, box_int32, LLVM.Value[val]) end function emit_box_int64!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value @@ -229,7 +230,7 @@ function emit_box_int64!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLV call!(B, FT, box_int64, [val]) end -function emit_apply_generic!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value +function emit_apply_generic!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = 
LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -252,11 +253,16 @@ function emit_apply_generic!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value vararg = true, ), ) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) + nargs = Vector{LLVM.Value}(undef, 1+length(args)) + nargs[1] = inv + for (i, v) in enumerate(args) + nargs[1+i] = v + end + res = call!(B, FT, julia_call, nargs) return res end -function emit_invoke!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value +function emit_invoke!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -281,11 +287,16 @@ function emit_invoke!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value vararg = true, ), ) - res = call!(B, FT, julia_call, [inv, args...]) + nargs = Vector{LLVM.Value}(undef, 1+length(args)) + nargs[1] = inv + for (i, v) in enumerate(args) + nargs[1+i] = v + end + res = call!(B, FT, julia_call, nargs) return res end -function emit_svec!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value +function emit_svec!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -297,7 +308,13 @@ function emit_svec!(B::LLVM.IRBuilder, @nospecialize(args))::LLVM.Value LLVM.FunctionType(T_prjlvalue, [sz]; vararg = true) sz = convert(LLVMType, Csize_t) - call!(B, fty, fn, [LLVM.ConstantInt(sz, length(args)), args...]) + + nargs = Vector{LLVM.Value}(undef, 1+length(args)) + nargs[1] = LLVM.ConstantInt(sz, length(args)) + for (i, v) in enumerate(args) + nargs[1+i] = v + end + call!(B, fty, fn, nargs) end @@ -309,7 +326,7 @@ function load_if_mixed(oval::OT, val::VT) where {OT, VT} end end -function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) +function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) world = 
enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(oval) if !legal @@ -384,13 +401,13 @@ function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Valu end end -function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value +function emit_apply_type!(B::LLVM.IRBuilder, @nospecialize(Ty::Type), args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) legal = true - found = [] + found = Any[] for arg in args slegal, foundv = absint(arg) if slegal @@ -412,7 +429,6 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) f_apply_type, _ = get_function!(mod, "jl_f_apply_type", generic_FT) - Ty = unsafe_to_llvm(B, Ty) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) julia_call, FT = get_function!( @@ -424,22 +440,29 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value vararg = true, ), ) + nargs = Vector{LLVM.Value}(undef, 3+length(args)) + nargs[1] = f_apply_type + nargs[2] = LLVM.PointerNull(T_prjlvalue) + nargs[3] = unsafe_to_llvm(B, Ty) + for (i, v) in enumerate(args) + nargs[3+i] = v + end tag = call!( B, FT, julia_call, - LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...], + nargs ) return tag end -function emit_tuple!(B, args)::LLVM.Value +function emit_tuple!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) legal = true - found = [] + found = Any[] for arg in args slegal, foundv = absint(arg) if slegal @@ -472,16 +495,22 @@ function emit_tuple!(B, args)::LLVM.Value vararg = true, ), ) + nargs = Vector{LLVM.Value}(undef, 2+length(args)) + nargs[1] = f_apply_type + nargs[2] = 
LLVM.PointerNull(T_prjlvalue) + for (i, v) in enumerate(args) + nargs[2+i] = v + end tag = call!( B, FT, julia_call, - LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...], + nargs ) return tag end -function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value +function emit_jltypeof!(B::LLVM.IRBuilder, @nospecialize(arg::LLVM.Value))::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -498,7 +527,7 @@ function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value call!(B, FT, fn, [arg]) end -function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value +function emit_methodinstance!(B::LLVM.IRBuilder, @nospecialize(func), args::Vector{LLVM.Value})::LLVM.Value curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -581,7 +610,7 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value return mi end -function emit_writebarrier!(B, T) +function emit_writebarrier!(B::LLVM.IRBuilder, T::Vector{LLVM.Value}) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -685,7 +714,7 @@ function get_memory_struct() return LLVM.StructType([sizeT, ptrty]; packed = true) end -function get_memory_data(B, array) +function get_memory_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) mty = get_memory_struct() array = LLVM.pointercast!( B, @@ -766,7 +795,7 @@ function get_datatype_struct() return LLVM.StructType([jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, jlvaluet, i32, i16]; packed = true) end -function get_array_data(B, array) +function get_array_data(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) i8 = LLVM.IntType(8) ptrty = LLVM.PointerType(i8, 13) array = LLVM.pointercast!( @@ -777,7 +806,7 @@ function get_array_data(B, array) return LLVM.load!(B, ptrty, array) end -function get_array_elsz(B, array) +function get_array_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ST = get_array_struct() elsz = LLVM.IntType(16) 
array = LLVM.pointercast!( @@ -794,7 +823,7 @@ function get_array_elsz(B, array) return LLVM.load!(B, elsz, v) end -function emit_layout_of_type!(B, ty) +function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) legal, JTy = absint(ty) ls = get_layout_struct() lptr = LLVM.PointerType(ls, 10) @@ -814,7 +843,7 @@ function emit_layout_of_type!(B, ty) return layout end -function emit_memorytype_elsz!(B, ty) +function emit_memorytype_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value)) legal, JTy = absint(ty) if legal res = unsafe_load(reinterpret(Ptr{UInt32}, JTy.layout)) @@ -828,12 +857,12 @@ function emit_memorytype_elsz!(B, ty) return load!(B, i32, lty) end -function get_memory_elsz(B, array) +function get_memory_elsz(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ty = emit_jltypeof!(B, array) return emit_memorytype_elsz!(B, ty) end -function get_array_len(B, array) +function get_array_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) if isa(array, LLVM.CallInst) fn = LLVM.called_operand(array) nm = "" @@ -874,7 +903,7 @@ function get_array_len(B, array) return LLVM.load!(B, sizeT, v) end -function get_memory_len(B, array) +function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) if isa(array, LLVM.CallInst) fn = LLVM.called_operand(array) nm = "" @@ -911,7 +940,7 @@ function get_memory_len(B, array) return LLVM.load!(B, sizeT, v) end -function get_array_nrows(B, array) +function get_array_nrows(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value)) ST = get_array_struct() array = LLVM.pointercast!( B, @@ -928,7 +957,7 @@ function get_array_nrows(B, array) return LLVM.load!(B, nrows, v) end -function emit_gc_preserve_begin(B::LLVM.IRBuilder, args = LLVM.Value[]) +function emit_gc_preserve_begin(B::LLVM.IRBuilder, args::Vector{LLVM.Value} = LLVM.Value[]) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -942,7 +971,7 @@ function emit_gc_preserve_begin(B::LLVM.IRBuilder, args = 
LLVM.Value[]) return token end -function emit_gc_preserve_end(B::LLVM.IRBuilder, token) +function emit_gc_preserve_end(B::LLVM.IRBuilder, @nospecialize(token::LLVM.Value)) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -957,20 +986,20 @@ function emit_gc_preserve_end(B::LLVM.IRBuilder, token) return end -function allocate_sret!(B::LLVM.IRBuilder, N) +function allocate_sret!(B::LLVM.IRBuilder, @nospecialize(N::LLVM.LLVMType)) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) al = LLVM.alloca!(B, LLVM.ArrayType(T_prjlvalue, N)) return al end -function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) +function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, @nospecialize(N::LLVM.LLVMType)) B = LLVM.IRBuilder() position!(B, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) allocate_sret!(B, N) end -function emit_error(B::LLVM.IRBuilder, orig, string, errty = EnzymeRuntimeException) +function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::String, @nospecialize(errty::Type) = EnzymeRuntimeException) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 03611e7d26..55f3136286 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1,10 +1,10 @@ function enzyme_custom_setup_args( - B, + @nospecialize(B::Union{Nothing, LLVM.IRBuilder}), orig::LLVM.CallInst, gutils::GradientUtils, - mi, - @nospecialize(RT), + mi::Core.MethodInstance, + @nospecialize(RT::Type), reverse::Bool, isKWCall::Bool, ) @@ -356,9 +356,9 @@ end function enzyme_custom_setup_ret( gutils::GradientUtils, orig::LLVM.CallInst, - mi, - @nospecialize(RealRt), - B, + mi::Core.MethodInstance, + @nospecialize(RealRt::Type), + @nospecialize(B::Union{LLVM.IRBuilder,Nothing}) ) width = get_width(gutils) mode = get_mode(gutils) @@ -448,11 +448,11 @@ function 
enzyme_custom_setup_ret( return RT, needsPrimal, needsShadowP[] != 0, origNeedsPrimal end -function custom_rule_method_error(world, fn, args...) +function custom_rule_method_error(world::UInt, @nospecialize(fn), @nospecialize(args::Vararg)) throw(MethodError(fn, (args...,), world)) end -@register_fwd function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) +@register_fwd function enzyme_custom_fwd(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::GradientUtils, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end @@ -527,7 +527,7 @@ end if EnzymeRules.isapplicable(kwfunc, TT; world) @safe_debug "Applying custom forward rule (kwcall)" TT llvmf = nested_codegen!(mode, mod, kwfunc, TT, world) - fwd_RT = Core.Compiler.return_type(kwfunc, TT, world) + fwd_RT = Compiler.primal_return_type_world(Forward, world, Core.Typeof(kwfunc), TT) else TT = Tuple{typeof(world),typeof(kwfunc),TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) @@ -538,7 +538,7 @@ end if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world) @safe_debug "Applying custom forward rule" TT llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world) - fwd_RT = Core.Compiler.return_type(EnzymeRules.forward, TT, world) + fwd_RT = Compiler.primal_return_type_world(Forward, world, typeof(EnzymeRules.forward), TT) else TT = Tuple{typeof(world),typeof(EnzymeRules.forward),TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) @@ -768,8 +768,8 @@ end @inline function aug_fwd_mi( orig::LLVM.CallInst, gutils::GradientUtils, - forward = false, - B = nothing, + forward::Bool = false, + @nospecialize(B::Union{Nothing, LLVM.IRBuilder}) = nothing, ) width = get_width(gutils) @@ -874,18 +874,18 @@ end ) end -@inline function has_aug_fwd_rule(orig, gutils) +@inline function has_aug_fwd_rule(orig::LLVM.CallInst, gutils::GradientUtils) 
return aug_fwd_mi(orig, gutils)[1] !== nothing end -@register_rev function enzyme_custom_common_rev( +function enzyme_custom_common_rev( forward::Bool, - B, + B::LLVM.IRBuilder, orig::LLVM.CallInst, - gutils, - normalR, - shadowR, - tape, + gutils::GradientUtils, + normalR::Ptr{LLVM.API.LLVMValueRef}, + shadowR::Ptr{LLVM.API.LLVMValueRef}, + tape::Union{Nothing, LLVM.Value}, )::LLVM.API.LLVMValueRef ctx = LLVM.context(orig) @@ -1030,7 +1030,7 @@ end @safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT try llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) - rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) + rev_RT = Compiler.primal_return_type_world(Reverse, world, Core.Typeof(rkwfunc), rev_TT) catch e rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) @@ -1050,7 +1050,7 @@ end @safe_debug "Applying custom reverse rule" TT = rev_TT try llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) - rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) + rev_RT = Compiler.primal_return_type_world(Reverse, world, typeof(EnzymeRules.reverse), rev_TT) catch e rev_TT = Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} @@ -1121,7 +1121,7 @@ end if !forward funcTy = rev_TT.parameters[isKWCall ? 
4 : 2] if needsTape - @assert tape != C_NULL + @assert tape isa LLVM.Value tape_idx = 1 + (kwtup !== nothing && !isghostty(kwtup)) + @@ -1574,7 +1574,7 @@ end end -@register_aug function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR) +@register_aug function enzyme_custom_augfwd(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::GradientUtils, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef}) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) @@ -1587,17 +1587,17 @@ end return false end -@register_rev function enzyme_custom_rev(B, orig, gutils, tape) +@register_rev function enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::GradientUtils, @nospecialize(tape::Union{Nothing, LLVM.Value})) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return end - enzyme_custom_common_rev(false, B, orig, gutils, C_NULL, C_NULL, tape) #=tape=# + enzyme_custom_common_rev(false, B, orig, gutils, reinterpret(Ptr{LLVM.API.LLVMValueRef}, C_NULL), reinterpret(Ptr{LLVM.API.LLVMValueRef}, C_NULL), tape) #=tape=# return nothing end -@register_diffuse function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) +@register_diffuse function enzyme_custom_diffuse(orig::LLVM.CallInst, gutils::GradientUtils, @nospecialize(val::LLVM.Value), isshadow::Bool, mode::API.CDerivativeMode) # use default if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 546231c3f4..3ca0d0a3cd 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -267,7 +267,8 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) # tt0 = Tuple{$(primtypes...)} tt = Tuple{$(ElTypes...)} tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) + FT = Core.Typeof(f) + rt = Compiler.primal_return_type(Forward, FT, 
tt) annotation = guess_activity(rt, API.DEM_ForwardMode) if annotation <: DuplicatedNoNeed @@ -280,15 +281,12 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end dupClosure = ActivityTup[1] - FT = Core.Typeof(f) if dupClosure && guaranteed_const(FT) dupClosure = false end - world = codegen_world_age(FT, tt) - opt_mi = Val(world) forward = thunk( - opt_mi, + Val(0), dupClosure ? $dupty : Const{FT}, annotation, tt′, @@ -438,23 +436,23 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) args = ($(wrapped...),) $(MakeTypes...) + FT = Core.Typeof(f) + dupClosure0 = if ActivityTup[1] + !guaranteed_const(FT) + else + false + end + + internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) gv = Core.getglobal(args[1].val, args[2].val) @assert sizeof(gv) == 0 (nothing, gv, nothing, Const) else - FT = Core.Typeof(f) + tt = Tuple{$(ElTypes...)} - world = codegen_world_age(FT, tt) - - dupClosure0 = if ActivityTup[1] - !guaranteed_const(FT) - else - false - end - - rt = Compiler.primal_return_type(Reverse, Val(world), FT, tt) + rt = Compiler.primal_return_type(Reverse, FT, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotationA = if $Width != 1 && annotation0 <: Duplicated @@ -464,10 +462,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) else annotation0 end - - opt_mi = Val(world) forward, adjoint = thunk( - opt_mi, + Val(0), dupClosure0 ? 
$dupty : Const{FT}, annotationA, Tuple{$(Types...)}, @@ -572,7 +568,7 @@ function nonzero_active_data(x::T) where {T} end function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs) - outs = [] + outs = Vector{Expr}(undef, N*Width) for i = 1:N for w = 1:Width expr = if Width == 1 @@ -600,16 +596,16 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act ) end end - push!(outs, out) + @inbounds outs[(i-1)*Width+w] = out end end shadow_ret = nothing if Width == 1 shadowret = :(tape.shadow_return[]) else - shadowret = [] + shadowret = Vector{Expr}(undef, Width) for w = 1:Width - push!(shadowret, :(tape.shadow_return[$w][])) + @inbounds shadowret[w] = :(tape.shadow_return[$w][]) end shadowret = :(($(shadowret...),)) end @@ -651,13 +647,8 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act else false end - tt = Tuple{$(ElTypes...)} - - world = codegen_world_age(FT, tt) - - rt = Compiler.primal_return_type(Reverse, Val(world), FT, tt) - + rt = Compiler.primal_return_type(Reverse, FT, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if $Width != 1 && annotation0 <: Duplicated @@ -666,9 +657,8 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation0 end - opt_mi = Val(world) _, adjoint = thunk( - opt_mi, + Val(0), dupClosure0 ? 
$dupty : Const{FT}, annotation, Tuple{$(Types...)}, @@ -991,7 +981,7 @@ function fwddiff_with_return( tt = Enzyme.vaEltypes(tt′) - rt = Core.Compiler.return_type(f, tt) + rt = Compiler.primal_return_type(Forward, FT, tt) annotation0 = guess_activity(rt, API.DEM_ForwardMode) annotation = if width != 1 @@ -1008,7 +998,6 @@ function fwddiff_with_return( end end - world = codegen_world_age(FT, tt) fa = if dupClosure if width == 1 Duplicated(f, df) @@ -1018,9 +1007,8 @@ function fwddiff_with_return( else Const(f) end - opt_mi = Val(world) res = thunk( - opt_mi, + Val(0), FA, annotation, tt′, @@ -1124,15 +1112,13 @@ end ) where {Ann,Nargs} expr = Vector{Expr}(undef, Nargs) for i = 1:Nargs - @inbounds expr[i] = quote - @assert !(args[$i] isa Active) - if args[$i] isa Const - args[$i].val - elseif args[$i] isa MixedDuplicated - args[$i].dval[] - else - args[$i].dval - end + @assert !(args[i] <: Active) + @inbounds expr[i] = if args[i] <: Const + :(args[$i].val) + elseif args[i] <: MixedDuplicated + :(args[$i].dval[]) + else + :(args[$i].dval) end end rval = :(($(expr...),)) @@ -1154,19 +1140,17 @@ end for w = 1:width expr = Vector{Expr}(undef, Nargs) for i = 1:Nargs - @inbounds expr[i] = quote - @assert !(args[$i] isa Active) - if args[$i] isa Const - args[$i].val - elseif args[$i] isa BatchMixedDuplicated - args[$i].dval[$w][] - else - args[$i].dval[$w] - end + @assert !(args[i] <: Active) + @inbounds expr[i] = if args[i] <: Const + :(args[$i].val) + elseif args[i] <: BatchMixedDuplicated + :(args[$i].dval[$w][]) + else + :(args[$i].dval[$w]) end end rval = :(($(expr...),)) - if Ann <: BatchMixedDuplicated + if Ann <: BatchMixedDuplicated || Ann <: MixedDuplicated rval = :(Ref($rval)) end @inbounds wexpr[w] = rval @@ -1205,32 +1189,33 @@ function augfwd_with_return( ModifiedBetween = Val(ModifiedBetween0) tt = Enzyme.vaEltypes(tt′) - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - annotation = if width != 1 - if 
annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt,width} - elseif annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt,width} - elseif annotation0 <: Active - Active{rt} - else - Const{rt} - end - else - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - Duplicated{rt} - elseif annotation0 <: MixedDuplicated - MixedDuplicated{rt} - elseif annotation0 <: Active - Active{rt} + internal_tape, origRet, initShadow = if f != Base.tuple + rt = Compiler.primal_return_type(Reverse, FT, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt,width} + elseif annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt,width} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end else - Const{rt} + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + elseif annotation0 <: MixedDuplicated + MixedDuplicated{rt} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end end - end - internal_tape, origRet, initShadow = if f != Base.tuple dupClosure = dupClosure0 && !guaranteed_const(FT) FA = dupClosure ? Duplicated{FT} : Const{FT} @@ -1243,10 +1228,8 @@ function augfwd_with_return( else Const(f) end - world = codegen_world_age(FT, tt) - opt_mi = Val(world) forward, adjoint = thunk( - opt_mi, + Val(0), FA, annotation, tt′, @@ -1261,21 +1244,22 @@ function augfwd_with_return( ) #=erriffuncwritten=# forward(fa, args...) else + annotation0 = guess_activity(tt, API.DEM_ReverseModePrimal) nothing, primal_tuple(args...), - annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) + annotation0 <: Active ? nothing : shadow_tuple(annotation0, Val(width), args...) 
end resT = typeof(origRet) - if annotation <: Const + if annotation0 <: Const shadow_return = nothing tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( internal_tape, shadow_return, ) return ReturnType((allSame(Val(width + 1), origRet)..., tape)) - elseif annotation <: Active + elseif annotation0 <: Active shadow_return = if width == 1 Ref(make_zero(origRet)) else @@ -1293,7 +1277,7 @@ function augfwd_with_return( end if width == 1 - if annotation <: MixedDuplicated + if annotation0 <: MixedDuplicated shadow_return = initShadow tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( internal_tape, @@ -1309,7 +1293,7 @@ function augfwd_with_return( return ReturnType((origRet, initShadow, tape)) end else - if annotation <: BatchMixedDuplicated + if annotation0 <: MixedDuplicated shadow_return = initShadow tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( internal_tape, @@ -1445,63 +1429,57 @@ end Nargs, } - nontupexprs = Vector{Expr}(undef, Nargs) + nontupexprs = Vector{Union{Symbol,Expr}}(undef, Nargs) for i = 1:Nargs - mid = if width == 1 - :(tape.shadow_return[][$i]) - else - mexprs = Vector{Expr}(undef, width) - for w = 1:width - @inbounds mexprs[w] = :(tape.shadow_return[$w][][$i]) - end - quote - ($(mexprs...),) - end - end - - @inbounds nontupexprs[i] = quote - if args[$i] isa Active || - args[$i] isa MixedDuplicated || - args[$i] isa BatchMixedDuplicated - $mid + @inbounds nontupexprs[i] = if args[i] <: Active || args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated + if width == 1 + :(tape.shadow_return[][$i]) else - nothing + mexprs = Vector{Expr}(undef, width) + for w = 1:width + @inbounds mexprs[w] = :(tape.shadow_return[$w][][$i]) + end + quote + ($(mexprs...),) + end end + else + :nothing end end endexprs = Matrix{Expr}(undef, Nargs, width) for i = 1:Nargs for w = 1:width - @inbounds endexprs[i, w] = quote - if args[$i] isa Active || - args[$i] isa MixedDuplicated || - args[$i] isa BatchMixedDuplicated - expr = 
if args[$i] isa Active || f == Base.tuple - if $width == 1 - tup[$i] - else - tup[$i][$w] - end - elseif args[$i] isa MixedDuplicated - args[$i].dval[] + @inbounds endexprs[i, w] = if args[i] <: Active || args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated + expr = if args[i] <: Active || f <: typeof(Base.tuple) + if width == 1 + :(tup[$i]) else - # if args[$i] isa BatchMixedDuplicated - args[$i].dval[$w][] + :(tup[$i][$w]) end + elseif args[i] <: MixedDuplicated + :(args[$i].dval[]) + else + :(args[$i].dval[$w][]) + end + quote idx_of_vec, idx_in_vec = $(lengths[i]) vec = @inbounds shadowargs[idx_of_vec][$w] if vec isa Base.RefValue vecld = vec[] T = Core.Typeof(vecld) @assert !(vecld isa Base.RefValue) - vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), expr) + vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), $expr) else val = @inbounds vec[idx_in_vec] - add_into_vec!(Base.inferencebarrier(val), expr, vec, idx_in_vec) + add_into_vec!(Base.inferencebarrier(val), $expr, vec, idx_in_vec) end end + else + quote + end end end end @@ -1556,20 +1534,17 @@ end tt = $tt - rt = Core.Compiler.return_type(f, tt) + rt = Compiler.primal_return_type(Reverse, FT, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - annotation = $annotation - world = codegen_world_age(FT, tt) fa = if dupClosure $(width == 1 ? 
:Duplicated : :BatchDuplicated)(f, df) else Const(f) end - opt_mi = Val(world) forward, adjoint = thunk( - opt_mi, + Val(0), FA, annotation, $ttp, @@ -1874,7 +1849,7 @@ function generic_setup( pushfirst!(vals, unsafe_to_llvm(B, Val(get_runtime_activity(gutils)))) end etup0 = emit_tuple!(B, ActivityList) - etup = emit_apply_type!(B, Base.Val, [etup0]) + etup = emit_apply_type!(B, Base.Val, LLVM.Value[etup0]) if isa(etup, LLVM.Instruction) @assert length(collect(LLVM.uses(etup0))) == 1 end diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index d882fc2672..08f3f04afd 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -1,23 +1,21 @@ function runtime_newtask_fwd( - world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}, -) where {FT1,FT2,World,width,RuntimeActivity} +) where {FT1,FT2,width,RuntimeActivity} FT = Core.Typeof(fn) ghos = guaranteed_const(FT) - opt_mi = world forward = thunk( - opt_mi, + Val(0), (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), - Val(width), + Val(Int(width)), Val((false,)), Val(true), Val(false), @@ -35,7 +33,6 @@ function runtime_newtask_fwd( end function runtime_newtask_augfwd( - world::Val{World}, fn::FT1, dfn::FT2, post::Any, @@ -43,18 +40,17 @@ function runtime_newtask_augfwd( runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{ModifiedBetween}, -) where {FT1,FT2,World,width,ModifiedBetween,RuntimeActivity} +) where {FT1,FT2,width,ModifiedBetween,RuntimeActivity} # TODO make this AD subcall type stable FT = Core.Typeof(fn) ghos = guaranteed_const(FT) - opt_mi = world forward, adjoint = thunk( - opt_mi, + Val(0), (ghos ? 
Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), - Val(width), + Val(Int(width)), Val(ModifiedBetween), Val(true), Val(false), @@ -416,7 +412,7 @@ end al = addrspacecast!(B, al, LLVM.PointerType(ll_th, Derived)) push!(vals, al) - copies = [] + copies = Tuple{LLVM.Value, LLVM.Value, LLVM.LLVMType}[] if !isghostty(dfuncT) llty = convert(LLVMType, dfuncT) @@ -446,7 +442,7 @@ end val = bitcast!(B, val, LLVM.PointerType(pllty, addrspace(value_type(val)))) val = addrspacecast!(B, val, LLVM.PointerType(pllty, Derived)) store!(B, v, val) - if pv !== nothing + if !(pv isa Nothing) push!(copies, (pv, val, pllty)) end @@ -671,7 +667,6 @@ end vals = LLVM.Value[ unsafe_to_llvm(B, runtime_newtask_fwd), - unsafe_to_llvm(B, Val(world)), new_from_original(gutils, ops[1]), invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), @@ -727,7 +722,6 @@ end vals = LLVM.Value[ unsafe_to_llvm(B, runtime_newtask_augfwd), - unsafe_to_llvm(B, Val(world)), new_from_original(gutils, ops[1]), invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 723ed23a31..fd157890e2 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -108,7 +108,7 @@ function body_construct_rev( batchshadowargs, tuple, ) - outs = [] + outs = Vector{Expr}(undef, N*Width) for i = 1:N for w = 1:Width tsym = Symbol("tval_$w") @@ -131,15 +131,16 @@ function body_construct_rev( end end ) - push!(outs, out) + @inbounds outs[(i-1)*Width+w] = out end end - tapes = Expr[:(tval_1 = tape[])] + tapes = Vector{Expr}(undef, Width) + @inbounds tapes[1] = :(tval_1 = tape[]) for w = 2:Width sym = Symbol("tval_$w") df = Symbol("df_$w") - push!(tapes, :($sym = $df[])) + @inbounds tapes[w] = :($sym = $df[]) end quote @@ -1413,7 +1414,7 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta push!(vals, inps[1]) sym = new_from_original(gutils, ops[3]) - 
sym = emit_apply_type!(B, Base.Val, [sym]) + sym = emit_apply_type!(B, Base.Val, LLVM.Value[sym]) push!(vals, sym) push!(vals, unsafe_to_llvm(B, Val(is_constant_value(gutils, ops[2])))) @@ -1510,7 +1511,7 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) sym = new_from_original(gutils, ops[3]) sym = lookup_value(gutils, sym, B) - sym = emit_apply_type!(B, Base.Val, [sym]) + sym = emit_apply_type!(B, Base.Val, LLVM.Value[sym]) push!(vals, sym) push!(vals, unsafe_to_llvm(B, Val(is_constant_value(gutils, ops[2])))) @@ -1606,7 +1607,7 @@ end sym = new_from_original(gutils, ops[2]) sym = (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, sym) - sym = emit_apply_type!(B, Base.Val, [sym]) + sym = emit_apply_type!(B, Base.Val, LLVM.Value[sym]) push!(vals, sym) push!(vals, unsafe_to_llvm(B, Val(is_constant_value(gutils, ops[1])))) @@ -1705,7 +1706,7 @@ end sym = new_from_original(gutils, ops[2]) sym = lookup_value(gutils, sym, B) sym = (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, sym) - sym = emit_apply_type!(B, Base.Val, [sym]) + sym = emit_apply_type!(B, Base.Val, LLVM.Value[sym]) push!(vals, sym) push!(vals, unsafe_to_llvm(B, Val(is_constant_value(gutils, ops[1])))) diff --git a/src/utils.jl b/src/utils.jl index 0e23ca486b..d5d0ed733a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -134,10 +134,6 @@ end import Base: allocatedinline -#Excerpt from https://github.com/JuliaGPU/GPUCompiler.jl/blob/v0.19.4/src/jlgen.jl -# !!! warning "codegen_world_age below is fundamentally unsound." -# It was removed from GPUCompiler since it can produce incorrect results. - using Core: MethodInstance using GPUCompiler: tls_world_age, MethodError, methodinstance using Core.Compiler: retrieve_code_info, CodeInfo, SSAValue, ReturnNode @@ -146,118 +142,6 @@ using Base: _methods_by_ftype # Julia compiler integration -## world age lookups - -# `tls_world_age` should be used to look up the current world age. 
in most cases, this is -# what you should use to invoke the compiler with. -# -# `codegen_world_age` is a special function that returns the world age in which the passed -# method instance (identified by its function and argument types) is to be compiled. the -# returned constant is automatically invalidated when the method is redefined, and as such -# can be used to drive cached compilation. it is unlikely that you should use this function -# directly, instead use `cached_compilation` which handles invalidation for you. - - -# on 1.10 (JuliaLang/julia#48611) the generated function knows which world it was invoked in - -function _generated_ex(world, source, ex) - stub = Core.GeneratedFunctionStub( - identity, - Core.svec(:methodinstance, :ft, :tt), - Core.svec(), - ) - stub(world, source, ex) -end - -function codegen_world_age_generator(world::UInt, source, self, ft::Type, tt::Type) - @nospecialize - @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) - ft = ft.parameters[1] - tt = tt.parameters[1] - - # validation - ft <: Core.Builtin && - error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") - - # look up the method - method_error = :(throw(MethodError(ft, tt, $world))) - sig = Tuple{ft,tt.parameters...} - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - mthds = Base._methods_by_ftype( - sig, - nothing, - -1, #=lim=# - world, - false, #=ambig=# - min_world, - max_world, - has_ambig, - ) - mthds === nothing && return _generated_ex(world, source, method_error) - length(mthds) == 1 || return _generated_ex(world, source, method_error) - - # look up the method and code instance - mtypes, msp, m = mthds[1] - mi = ccall( - :jl_specializations_get_linfo, - Ref{MethodInstance}, - (Any, Any, Any), - m, - mtypes, - msp, - ) - ci = retrieve_code_info(mi, world)::CodeInfo - - # prepare a new code info - new_ci = copy(ci) - empty!(new_ci.code) - 
@static if isdefined(Core, :DebugInfo) - new_ci.debuginfo = Core.DebugInfo(:none) - else - empty!(new_ci.codelocs) - resize!(new_ci.linetable, 1) # see note below - end - empty!(new_ci.ssaflags) - new_ci.ssavaluetypes = 0 - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] - new_ci.edges = MethodInstance[mi] - # XXX: setting this edge does not give us proper method invalidation, see - # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. - # invoking `code_llvm` also does the necessary codegen, as does calling the - # underlying C methods -- which GPUCompiler does, so everything Just Works. - - # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:3] - - # return the codegen world age - push!(new_ci.code, ReturnNode(world)) - push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` - @static if isdefined(Core, :DebugInfo) - else - push!(new_ci.codelocs, 1) # see note below - end - new_ci.ssavaluetypes += 1 - - # NOTE: we keep the first entry of the original linetable, and use it for location info - # on the call to check_cache. we can't not have a codeloc (using 0 causes - # corruption of the back trace), and reusing the target function's info - # has as advantage that we see the name of the kernel in the backtraces. 
- - return new_ci -end - -@eval function codegen_world_age(ft, tt) - $(Expr(:meta, :generated_only)) - $(Expr(:meta, :generated, codegen_world_age_generator)) -end - -export codegen_world_age - - if VERSION >= v"1.11.0-DEV.1552" diff --git a/test/ext/chainrulescore.jl b/test/ext/chainrulescore.jl index 65984ef26f..9526bd40d6 100644 --- a/test/ext/chainrulescore.jl +++ b/test/ext/chainrulescore.jl @@ -21,7 +21,7 @@ end function ChainRulesCore.rrule(::typeof(MockModule.mock_function), x) y = MockModule.mock_function(x) - return y, ȳ -> 2 * ȳ + return y, ȳ -> (NoTangent(), MockModule.MockType(2 * ȳ)) end fdiff(f, x::Number) = autodiff(ForwardWithPrimal, f, Duplicated, Duplicated(x, one(x)))[1] diff --git a/test/optimize.jl b/test/optimize.jl index d13a6ed752..04fda17f25 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -142,3 +142,66 @@ end @testset "Memcopy of constant" begin @test Enzyme.autodiff(Enzyme.Forward, mc_f, Duplicated(2.7, 1.0))[1] ≈ 0.0 end + +module RetTypeMod + using Enzyme + struct Stacked + end + + @inline function myrand(td::Stacked, num_samples::Int) + return Base.inferencebarrier(ones(1)) + end + + struct TestProb1 end + + logdensity(::TestProb1, θ) = sum(θ) + + struct TestProb2 end + + logdensity(::TestProb2, θ) = sum(θ) + + struct MvLocationScale + end + + # This specialization improves AD performance of the sampling path + @inline function myrand( + q::MvLocationScale, num_samples::Int + ) + return ones(5, num_samples) + end + + function mymean(problem, A::AbstractArray) + isempty(A) && return sum(Base.Fix1(logdensity, problem), A) + x1 = sum(@inbounds first(A)) + return 1.0 + end + + function estimate_repgradelbo_ad_forward(problem, model) + zs = myrand(model, 10) + return mymean(problem, eachcol(zs)) + end + + function main() + d = 5 + for prob in [TestProb1(), TestProb2()] + q = if prob isa TestProb1 + MvLocationScale() + else + Stacked() + end + + Enzyme.autodiff( + Enzyme.Reverse, + estimate_repgradelbo_ad_forward, + Enzyme.Active, + 
Enzyme.Const(prob), + Enzyme.Const(q), + ) + end + end + +end + +@testset "Indirect function call return type analysis" begin + RetTypeMod.main() +end diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 62579e2415..704ada2b6e 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -33,11 +33,19 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) Base.delete_method(m) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 +@static if VERSION < v"1.11-" +@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 +else @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 +end # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +@static if VERSION < v"1.11-" +@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +else @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +end end # module diff --git a/test/runtests.jl b/test/runtests.jl index d9b7dc97ac..26587c3892 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -147,11 +147,10 @@ end @test Enzyme.Compiler.active_reg_inner(Tuple{Incomplete}, (), nothing, #=justActive=#Val(false)) == Enzyme.Compiler.MixedState @test Enzyme.Compiler.active_reg_inner(Tuple{Incomplete}, (), nothing, #=justActive=#Val(true)) == Enzyme.Compiler.ActiveState - world = codegen_world_age(typeof(f0), Tuple{Float64}) - thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) - thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) - thunk_c = Enzyme.Compiler.thunk(Val(world), 
Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) - thunk_d = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_a = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_b = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_c = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_d = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) @test thunk_a.adjoint !== thunk_b.adjoint @test thunk_c.adjoint === thunk_a.adjoint @test thunk_c.adjoint === thunk_d.adjoint @@ -160,7 +159,7 @@ end @test thunk_a(Const(f0), Active(2.0), 2.0) == ((2.0,),) @test thunk_b(Const(f0), Const(2.0)) === ((nothing,),) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, 
Val(false), Val(false)) @test forward(Const(f0), Active(2.0)) == (nothing,nothing,nothing) @test pullback(Const(f0), Active(2.0), 1.0, nothing) == ((1.0,),) @@ -170,8 +169,7 @@ end end d = Duplicated([3.0, 5.0], [0.0, 0.0]) - world = codegen_world_age(typeof(mul2), Tuple{Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(0), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) res = forward(Const(mul2), d) @static if VERSION < v"1.11-" @@ -185,8 +183,7 @@ end @test d.dval[2] ≈ 3.0 d = Duplicated([3.0, 5.0], [0.0, 0.0]) - world = codegen_world_age(typeof(vrec), Tuple{Int, Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(0), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) res = forward(Const(vrec), Const(Int(1)), d) pullback(Const(vrec), Const(1), d, 1.0, res[1]) @test d.dval[1] ≈ 5.0 From f0d0895a66229dced6bcb0a5e86602c6a99026ee Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 26 Nov 2024 21:08:20 -0500 Subject: [PATCH 450/495] improve typeof (#2123) --- src/absint.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 3205dccc25..951db001d0 100644 --- a/src/absint.jl 
+++ b/src/absint.jl @@ -560,7 +560,8 @@ function abs_typeof( break end - if fo != 0 && fo != typed_fieldoffset(typ, i-1) + if (i != typed_fieldcount(typ) && fo != typed_fieldoffset(typ, i+1)) || + (i == typed_fieldcount(typ) && fo != actual_size(typ)) lasti = i end end From 08bb15b5e71dffd834c36a1913c60cb3db87e20c Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 27 Nov 2024 00:06:18 -0500 Subject: [PATCH 451/495] fix stability (#2125) --- src/jlrt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jlrt.jl b/src/jlrt.jl index 223697fbe0..300fdc5515 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -337,7 +337,7 @@ function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nosp return val end end - return emit_apply_generic!(B, [unsafe_to_llvm(B, load_if_mixed), new_from_original(gutils, oval), val]) + return emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, load_if_mixed), new_from_original(gutils, oval), val]) end act = active_reg_inner(TT, (), world) if act == ActiveState || act == MixedState From 6a8e46141cec083758adecffc218fca420b59ffc Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 27 Nov 2024 12:12:22 -0500 Subject: [PATCH 452/495] simplify deferred failures (#2126) Co-authored-by: William Moses --- src/Enzyme.jl | 9 ++++++--- src/compiler.jl | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 5305beff8d..8c94911368 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -690,13 +690,13 @@ code, as well as high-order differentiation. rt = Compiler.primal_return_type(mode, FTy, tt) A2 = A{rt} if rt == Union{} - throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + rt = Nothing end else @assert A isa DataType rt = A if rt == Union{} - throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + throw(ErrorException("Return type inferred to be Union{}. 
Giving up.")) end end @@ -841,7 +841,10 @@ code, as well as high-order differentiation. if RT isa UnionAll rt = Compiler.primal_return_type(mode, FT, tt) - rt = RT{rt} + if rt == Union{} + rt = Nothing + end + rt = RT{rt} else @assert RT isa DataType rt = RT diff --git a/src/compiler.jl b/src/compiler.jl index e836f4bb8c..f266fd5f34 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3262,7 +3262,7 @@ function primal_return_type_generator(world::UInt, source, self, @nospecialize(m return new_ci end -@eval @inline function primal_return_type(mode::Mode, ft::Type, tt::Type) +@eval Base.@assume_effects :removable :foldable :nothrow @inline function primal_return_type(mode::Mode, ft::Type, tt::Type) $(Expr(:meta, :generated_only)) $(Expr(:meta, :generated, primal_return_type_generator)) end From 06e791eaea2b56dc9c4a8465ada77d9dbbb06472 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 27 Nov 2024 22:51:21 -0500 Subject: [PATCH 453/495] Update validation.jl (#2129) --- src/compiler/validation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 6378b2e0a6..2a5c860f64 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -459,11 +459,11 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp opv = operands(loadfn)[1] if !isa(opv, LLVM.GlobalVariable) for iv in instructions(last(blocks(initfn))) - if !(isa, LLVM.StoreInst) + if !(iv isa LLVM.StoreInst) continue end gv = operands(iv)[2] - if !(isa, LLVM.GlobalVariable) + if !(gv isa LLVM.GlobalVariable) continue end opv = gv From 45f01bd94ca560940328cd996cd7036276b7db9d Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 27 Nov 2024 21:27:40 -0800 Subject: [PATCH 454/495] make_zero(!) bugfixes and improved tests (#1961) * Fix make_zero(!) bugs * Add make_zero(!) tests Aiming for full coverage of both new and old implementations of make_zero(!) * Fix more make_zero(!) 
bugs and add more tests * Improve make_zero! error message * Simplify likely dead branch * Reinstate single-arg StaticArrays methods --- ext/EnzymeStaticArraysExt.jl | 47 ++- src/make_zero.jl | 237 +++++++----- test/abi.jl | 32 -- test/make_zero.jl | 725 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 912 insertions(+), 130 deletions(-) create mode 100644 test/make_zero.jl diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index c2639a4c99..ef955ebd9b 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,11 +32,50 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} + return Base.zero(prev)::FT end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + return Base.zero(prev)::FT +end + +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} + return Base.zero(prev)::FT +end +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} + if haskey(seen, prev) + return seen[prev] + end + new = Base.zero(prev)::FT + seen[prev] = new + return new +end + +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT, seen +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev, zero(T)) + 
return nothing +end +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + Enzyme.EnzymeCore.make_zero!(prev, nothing) + return nothing end end diff --git a/src/make_zero.jl b/src/make_zero.jl index f2fd055c61..5c7b49a749 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -1,4 +1,3 @@ - @inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} return Base.zero(x) end @@ -104,7 +103,7 @@ end prev::Complex{RT}, ::Val{copy_if_inactive} = Val(false), )::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) + return Complex{RT}(0) end @inline function EnzymeCore.make_zero( @@ -178,7 +177,9 @@ end prev::NamedTuple{A,RT}, ::Val{copy_if_inactive} = Val(false), )::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) + prevtup = RT(prev) + TT = Core.Typeof(prevtup) # RT can be abstract + return NamedTuple{A,RT}(EnzymeCore.make_zero(TT, seen, prevtup, Val(copy_if_inactive))) end @inline function EnzymeCore.make_zero( @@ -193,9 +194,7 @@ end prev2 = prev.contents res = Core.Box() seen[prev] = res - res.contents = Base.Ref( - EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), - ) + res.contents = EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)) return res end @@ -214,7 +213,6 @@ end @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) - if ismutable(prev) y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT seen[prev] = y @@ -232,11 +230,9 @@ end end return y end - if nf == 0 return prev end - flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) @@ -254,48 +250,71 @@ end end function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - zero(T) + return zero(T) end function make_zero_immutable!( prev::Complex{T}, seen::S, )::Complex{T} where 
{T<:AbstractFloat,S} - zero(T) + return zero(Complex{T}) end function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} + if guaranteed_const_nongen(T, nothing) + return prev # unreachable from make_zero! + end ntuple(Val(length(T.parameters))) do i Base.@_inline_meta - make_zero_immutable!(prev[i], seen) + p = prev[i] + SBT = Core.Typeof(p) + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) + make_zero_immutable!(p, seen) + else + EnzymeCore.make_zero!(p, seen) + p + end end end function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i + if guaranteed_const_nongen(NamedTuple{a,b}, nothing) + return prev # unreachable from make_zero! + end + NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i Base.@_inline_meta - make_zero_immutable!(prev[a[i]], seen) + p = prev[a[i]] + SBT = Core.Typeof(p) + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) + make_zero_immutable!(p, seen) + else + EnzymeCore.make_zero!(p, seen) + p + end end) end function make_zero_immutable!(prev::T, seen::S)::T where {T,S} if guaranteed_const_nongen(T, nothing) - return prev + return prev # unreachable from make_zero! 
end - @assert !ismutable(prev) - - RT = Core.Typeof(prev) - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - + @assert !ismutabletype(T) + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# + flds[i] = if guaranteed_const_nongen(ST, nothing) + xi + elseif !ismutabletype(ST) make_zero_immutable!(xi, seen) else EnzymeCore.make_zero!(xi, seen) @@ -306,39 +325,63 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T,S} break end end - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T + return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{T}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} - T[] = zero(T) - nothing + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + prev[] = zero(T) + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} - T[] = zero(Complex{T}) - nothing + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + prev[] = zero(Complex{T}) + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{T,N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end @static if VERSION < v"1.11-" @@ -347,16 
+390,28 @@ else prev::GenericMemory{kind, T}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end end @@ -364,90 +419,88 @@ end prev::Base.RefValue{T}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, )::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end @static if VERSION < v"1.11-" else @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} 
EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, )::Nothing where {T<:AbstractFloat, kind} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end end @@ -457,82 +510,78 @@ end seen::ST, )::Nothing where {T,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) - pv = prev[] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + return nothing + elseif !ismutabletype(SBT) prev[] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - pv = prev.contents - T = Core.Typeof(pv) - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) + pv = prev.contents SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) - nothing + if guaranteed_const_nongen(SBT, nothing) + return nothing 
+ elseif !ismutabletype(SBT) + prev.contents = make_zero_immutable!(pv, seen) else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end -@inline function EnzymeCore.make_zero!( - prev::T, - seen::S = Base.IdSet{Any}(), -)::Nothing where {T,S} +@inline function EnzymeCore.make_zero!(prev::T, seen::S)::Nothing where {T,S} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(prev, seen) - return + if prev in seen + return nothing end @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - - if nf == 0 - return + return nothing end - push!(seen, prev) - for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) SBT = Core.Typeof(xi) - if guaranteed_const_nongen(SBT, nothing) + activitystate = active_reg_inner(SBT, (), nothing) + if activitystate == AnyState # guaranteed_const continue - end - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - setfield!(prev, i, make_zero_immutable!(xi, seen)) - nothing - else + elseif ismutabletype(T) && !ismutabletype(SBT) + yi = make_zero_immutable!(xi, seen) + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) + else + setfield!(prev, i, yi) + end + elseif activitystate == DupState EnzymeCore.make_zero!(xi, seen) - nothing + else + msg = "cannot set $xi to zero in-place, as it contains differentiable values in immutable positions" + throw(ArgumentError(msg)) end end end - return + return nothing end + +@inline EnzymeCore.make_zero!(prev) = EnzymeCore.make_zero!(prev, Base.IdSet()) diff --git a/test/abi.jl b/test/abi.jl index 20747f2aaa..b6898ac1ba 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -489,38 +489,6 @@ mulsin(x) = sin(x[1] * x[2]) @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] end -mutable struct ConstVal - x::Float64 - const y::Float64 -end - -struct WithIO{F} - v::Vector{Float64} - callback::F - function WithIO(v, io) - callback() = println(io, 
"hello") - return new{typeof(callback)}(v, callback) - end -end - -@testset "Make Zero" begin - v = ConstVal(2.0, 3.0) - dv = make_zero(v) - @test dv isa ConstVal - @test dv.x ≈ 0.0 - @test dv.y ≈ 0.0 - - f = WithIO([1.0, 2.0], stdout) - df = @test_nowarn try - # catch errors to get failed test instead of "exception outside of a @test" - make_zero(f) - catch e - showerror(stderr, e) - end - @test df.v == [0.0, 0.0] - @test df.callback === f.callback -end - @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) diff --git a/test/make_zero.jl b/test/make_zero.jl new file mode 100644 index 0000000000..cbe2f2159f --- /dev/null +++ b/test/make_zero.jl @@ -0,0 +1,725 @@ +module MakeZeroTests + +using Enzyme +using StaticArrays +using Test + +# Universal getters/setters for built-in and custom containers/wrappers +getx(w::Base.RefValue) = w[] +getx(w::Core.Box) = w.contents +getx(w) = first(w) +gety(w) = last(w) + +setx!(w::Base.RefValue, x) = (w[] = x) +setx!(w::Core.Box, x) = (w.contents = x) +setx!(w, x) = (w[begin] = x) +sety!(w, y) = (w[end] = y) + +# non-isbits MArray doesn't support setindex!, so requires a little hack +function setx!(w::MArray{S,T}, x) where {S,T} + if isbitstype(T) + w[begin] = x + else + w.data = (x, Base.tail(w.data)...) 
+ end + return x +end + +function sety!(w::MArray{S,T}, y) where {S,T} + if isbitstype(T) + w[end] = y + else + w.data = (Base.front(w.data)..., y) + end + return y +end + +struct Empty end + +mutable struct MutableEmpty end + +Base.:(==)(::MutableEmpty, ::MutableEmpty) = true + +struct Wrapper{T} + x::T +end + +Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) +getx(a::Wrapper) = a.x + +mutable struct MutableWrapper{T} + x::T +end + +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) + +getx(a::MutableWrapper) = a.x +setx!(a::MutableWrapper, x) = (a.x = x) + +struct DualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::DualWrapper, b::DualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::DualWrapper) = a.x +gety(a::DualWrapper) = a.y + +mutable struct MutableDualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::MutableDualWrapper) = a.x +gety(a::MutableDualWrapper) = a.y + +setx!(a::MutableDualWrapper, x) = (a.x = x) +sety!(a::MutableDualWrapper, y) = (a.y = y) + +struct Incomplete{T} + s::String + x::Float64 + w::T + z # not initialized + Incomplete(s, x, w) = new{typeof(w)}(s, x, w) +end + +function Base.:(==)(a::Incomplete, b::Incomplete) + (a === b) && return true + ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct MutableIncomplete{T} + s::String + const x::Float64 + y::Float64 + z # not initialized + w::T + function MutableIncomplete(s, x, y, w) + ret = new{typeof(w)}(s, x, y) + ret.w = w + return ret + end +end + 
+function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) + (a === b) && return true + if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) + return false + end + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct CustomVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) + +function Enzyme.EnzymeCore.make_zero( + ::Type{CV}, seen::IdDict, prev::CV, ::Val{copy_if_inactive} +) where {CV<:CustomVector{<:AbstractFloat},copy_if_inactive} + @info "make_zero(::CustomVector)" + if haskey(seen, prev) + return seen[prev] + end + new = CustomVector(zero(prev.data))::CV + seen[prev] = new + return new +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}, seen)::Nothing + @info "make_zero!(::CustomVector)" + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev.data, false) + return nothing +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) + return Enzyme.EnzymeCore.make_zero!(prev, nothing) +end + +struct WithIO{F} # issue 2091 + v::Vector{Float64} + callback::F + function WithIO(v, io) + callback() = println(io, "hello") + return new{typeof(callback)}(v, callback) + end +end + +macro test_noerr(expr) + return quote + @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + $(esc(expr)) + catch e + showerror(stderr, e) + end + end +end + +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64] + +const inactivetup = ("a", Empty(), MutableEmpty()) +const inactivearr = [inactivetup] + +const wrappers = [ + (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), + (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), + (name="struct{X}", f=Wrapper, N=1, 
mutable=false, typed=true), + + (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), + (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), + + (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), + (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), + (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), + + (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), + (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), + (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), + (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), + + (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), + (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), + (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), + + (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), + (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), + + (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), + (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), + + (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), + + (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), + (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), + + (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), + (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false), + + # StaticArrays extension + (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), + (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), + (name="MVector{1,X}", 
f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), + (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), + (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), + (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), + (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), + (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), +] + +@static if VERSION < v"1.11-" +else +_memory(x::Vector) = Memory{eltype(x)}(x) +push!( + wrappers, + (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), + (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), + (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), + (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), +) +end + +function test_make_zero() + @testset "scalars" begin + @testset "$T" for T in scalartypes + x = oneunit(T) + x_makez = make_zero(x) + @test typeof(x_makez) === T # correct type + @test x_makez == zero(T) # correct value + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + w = wrapper.f(x) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === T # correct type + @test getx(w_makez) == zero(T) # correct value + @test getx(w) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + @testset "doubly included in $(dualwrapper.name)" for + dualwrapper in filter(w -> (w.N == 2), wrappers) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) 
=== typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type + @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type + @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type + @test getx(getx(w_outer_makez)) === 
gety(getx(w_outer_makez)) # correct topology + @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value + @test getx(w_outer) === d_middle # no mutation of original + @test getx(d_middle) === gety(d_middle) # no mutation of original + @test getx(d_middle) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in wrappers + if wrapper.N == 1 + w = wrapper.f(inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === inactivearr # no mutation of original + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved topology + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactivearr # no mutation of original + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test getx(w) === a # no mutation of original + @test a[1] === 1.0 # 
no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + end + end + @testset "copy_if_inactive $value" for (value, args) in [ + ("unspecified", ()), + ("= false", (Val(false),)), + ("= true", (Val(true),)), + ] + a = [1.0] + w = Any[a, inactivearr, inactivearr] + w_makez = make_zero(w, args...) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(w_makez[1]) === typeof(a) # correct type + @test w_makez[1] == [0.0] # correct value + @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) + @test w[1] === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test w[2] === w[3] # no mutation of original + @test w[2] === inactivearr # no mutation of original + @test inactivearr[1] === inactivetup # no mutation of original + if args == (Val(true),) + @test typeof(w_makez[2]) === typeof(inactivearr) # correct type + @test w_makez[2] == inactivearr # correct value + @test w_makez[2][1] !== inactivetup # correct identity + else + @test w_makez[2] === inactivearr # correct value/type/identity + end + end + end + @testset "heterogeneous containers" begin + scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) + wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) + mwraps, mwrapsz = MutableWrapper.(scalars), 
MutableWrapper.(scalarsz) + items = (inactivetup..., scalars..., wraps..., mwraps...) + itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + c_makez = make_zero(c) + @test typeof(c_makez) === typeof(c) # correct type + @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type + @test c_makez == cz # correct value + @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities + @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original + @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + w_makez = @test_noerr make_zero(w) + if wrapper.N == 1 + xz, yz = getx(w_makez) + x, y = getx(w) + else + xz, yz = getx(w_makez), gety(w_makez) + x, y = getx(w), gety(w) + end + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(xz) === typeof(w) # correct type + @test typeof(yz) === typeof(a) # correct type + @test xz === w_makez # correct self-reference + @test yz == [0.0] # correct value + @test x === w # no mutation of original + @test y === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + a_makez = make_zero(typeof(a), seen, a) + @test typeof(a_makez) === typeof(a) # correct type + @test a_makez == [0.0] # correct value + @test a[1] === 1.0 # no mutation of original + @test 
haskey(seen, a) # original added to IdDict + @test seen[a] === a_makez # original points to zeroed value + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # include optional arg Val(false) to avoid calling the custom method directly; + # it should still be invoked + v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) + @test typeof(v_makez) === typeof(v) # correct type + @test typeof(v_makez.data) === typeof(a) # correct type + @test v_makez == CustomVector([0.0]) # correct value + @test v.data === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + arr_makez = make_zero(arr) + @views begin + @test typeof(arr_makez) === typeof(arr) # correct type + @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type + @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value + @test !isassigned(arr_makez, 4) # propagated undefined + @test all(arr[1:3] .=== values) # no mutation of original + @test !isassigned(arr, 4) # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === 
typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined + @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original + @test incomplete.w === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + df = @test_noerr make_zero(f) + @test df.v == [0.0, 0.0] + @test df.callback === f.callback + end + return nothing +end + +function test_make_zero!() + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + if wrapper.mutable + w = wrapper.f(x) + make_zero!(w) + @test typeof(getx(w)) === T # preserved type + @test getx(w) == zero(T) # correct value + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test x == 
oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @assert !dualwrapper.mutable # sanity check + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type + @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type + @test typeof(getx(getx(getx(w_outer)))) === T # preserved type + @test getx(getx(getx(w_outer))) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test getx(getx(w_outer)) === w_inner # preserved identity + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for + wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) + if wrapper.N == 1 + w = wrapper.f(inactivearr) + make_zero!(w) + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + make_zero!(w) + @test getx(w) === gety(w) # preserved topology + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end 
+ end + end + end + @testset "heterogeneous containers" begin + mwraps = MutableWrapper.(oneunit.(scalartypes)) + mwrapsz = MutableWrapper.(zero.(scalartypes)) + items = (inactivetup..., mwraps...) + itemsz = (inactivetup..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + make_zero!(c) + @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities + @test c == cz # correct value + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + @test_noerr make_zero!(w) + if wrapper.N == 1 + x, y = getx(w) + else + x, y = getx(w), gety(w) + end + @test x === w # preserved self-referential identity + @test y === a # preserved identity + @test a[1] === 0.0 # correct value + end + end + @testset "bring your own IdSet" begin + a = [1.0] + seen = Base.IdSet() + make_zero!(a, seen) + @test a[1] === 0.0 # correct value + @test (a in seen) # object added to IdSet + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # bringing own IdSet to avoid calling the custom method directly; + # it should still be invoked + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, Base.IdSet()) + @test v.data === a # preserved identity + @test a[1] === 0.0 # correct value + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + make_zero!(arr) + @views begin + @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved 
types + @test arr[1:3] == ["a", 0.0, [0.0]] # correct value + @test arr[3] === a # preserved identity + @test !isassigned(arr, 4) # preserved unassigned + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incompletearr = [Incomplete("a", 1.0, a)] + make_zero!(incompletearr) + @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined + @test incompletearr[1].w === a # preserved identity + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + make_zero!(incomplete) + @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined + @test incomplete.w === a # preserved identity + end + @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin + # old implementation triggered #1935 + # new implementation would work regardless due to limited use of justActive + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incompletetuparr = [(incomplete,)] + make_zero!(incompletetuparr) + @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type + @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value + @test incompletetuparr[1][1].w === a # preserved identity + end + end + @testset "active/mixed type error" begin + @test_throws ArgumentError make_zero!((1.0,)) + @test_throws ArgumentError make_zero!((1.0, [1.0])) + @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + fwrapped = [f] + @test_noerr make_zero!(fwrapped) + @test fwrapped[1] === f + @test fwrapped[1].v == [0.0, 0.0] + end + return nothing +end + +@testset "make_zero" test_make_zero() +@testset "make_zero!" 
test_make_zero!() + +end # module MakeZeroTests diff --git a/test/runtests.jl b/test/runtests.jl index 26587c3892..e331645378 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,7 @@ end include("abi.jl") include("typetree.jl") include("optimize.jl") +include("make_zero.jl") include("rules.jl") include("rrules.jl") From e9d303be9232b7ce1de6cc8234a225b5bbd0e7d9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 01:13:18 -0500 Subject: [PATCH 455/495] Remove julia level type rules (#2130) --- src/api.jl | 6 +- src/compiler.jl | 102 +---------------------------- src/rules/allocrules.jl | 14 +--- src/rules/typerules.jl | 140 ---------------------------------------- src/typeanalysis.jl | 18 ++++-- 5 files changed, 16 insertions(+), 264 deletions(-) delete mode 100644 src/rules/typerules.jl diff --git a/src/api.jl b/src/api.jl index 2861eba86f..3cdba76ae8 100644 --- a/src/api.jl +++ b/src/api.jl @@ -441,9 +441,9 @@ function CreateTypeAnalysis(logic, rulenames, rules) EnzymeTypeAnalysisRef, (EnzymeLogicRef, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t), logic, - rulenames, - rules, - length(rules), + rulenames isa Tuple{} ? C_NULL : rulenames, + rules isa Tuple{} ? C_NULL : rules, + rulenames isa Tuple{} ? 
0 : length(rules), ) end diff --git a/src/compiler.jl b/src/compiler.jl index f266fd5f34..cbfac114e3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3940,7 +3940,6 @@ function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool = tru return ty, byref end -include("rules/typerules.jl") include("rules/activityrules.jl") @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Const} = API.DFT_CONSTANT @@ -4073,107 +4072,8 @@ function enzyme!( convert(API.CDIFFE_TYPE, rt) end - rules = Dict{String,API.CustomRuleType}( - "jl_array_copy" => @cfunction( - inout_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "ijl_array_copy" => @cfunction( - inout_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "jl_genericmemory_copy_slice" => @cfunction( - inoutcopyslice_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "ijl_genericmemory_copy_slice" => @cfunction( - inoutcopyslice_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "jl_inactive_inout" => @cfunction( - inout_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "jl_excstack_state" => @cfunction( - int_return_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "ijl_excstack_state" => @cfunction( - int_return_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "julia.except_enter" => @cfunction( - int_return_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - 
Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - ) - logic = Logic() - TA = TypeAnalysis(logic, rules) + TA = TypeAnalysis(logic) retTT = if !isa(actualRetType, Union) && actualRetType <: Tuple && diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 83a9a22cd4..7c611b6c85 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -86,14 +86,6 @@ function array_shadow_handler( return ref end -function null_free_handler( - B::LLVM.API.LLVMBuilderRef, - ToFree::LLVM.API.LLVMValueRef, - Fn::LLVM.API.LLVMValueRef, -)::LLVM.API.LLVMValueRef - return C_NULL -end - function register_alloc_handler!(variants, alloc_handler, free_handler) for variant in variants API.EnzymeRegisterAllocationHandler(variant, alloc_handler, free_handler) @@ -120,10 +112,6 @@ end API.EnzymeGradientUtilsRef, ) ), - @cfunction( - null_free_handler, - LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) - ) + C_NULL ) end diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl deleted file mode 100644 index 67c4776c4a..0000000000 --- a/src/rules/typerules.jl +++ /dev/null @@ -1,140 +0,0 @@ - -function int_return_rule( - direction::Cint, - ret::API.CTypeTreeRef, - args::Ptr{API.CTypeTreeRef}, - known_values::Ptr{API.IntList}, - numArgs::Csize_t, - val::LLVM.API.LLVMValueRef, -)::UInt8 - TT = TypeTree(API.DT_Integer, LLVM.context(LLVM.Value(val))) - only!(TT, -1) - API.EnzymeMergeTypeTree(ret, TT) - return UInt8(false) -end - -function inout_rule( - direction::Cint, - ret::API.CTypeTreeRef, - args::Ptr{API.CTypeTreeRef}, - known_values::Ptr{API.IntList}, - numArgs::Csize_t, - val::LLVM.API.LLVMValueRef, -)::UInt8 - if numArgs != 1 - return UInt8(false) - end - inst = LLVM.Instruction(val) - - legal, typ = abs_typeof(inst) - - if legal - if (direction & API.DOWN) != 0 - ctx = LLVM.context(inst) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) - if 
GPUCompiler.deserves_retbox(typ) - rest = copy(rest) - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - end - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) - @assert legal - end - return UInt8(false) - end - - if (direction & API.UP) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args), ret) - @assert legal - end - if (direction & API.DOWN) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args)) - @assert legal - end - return UInt8(false) -end - -function inoutcopyslice_rule( - direction::Cint, - ret::API.CTypeTreeRef, - args::Ptr{API.CTypeTreeRef}, - known_values::Ptr{API.IntList}, - numArgs::Csize_t, - val::LLVM.API.LLVMValueRef, -)::UInt8 - if numArgs != 1 - return UInt8(false) - end - inst = LLVM.Instruction(val) - - legal, typ = abs_typeof(inst) - - if legal - if (direction & API.DOWN) != 0 - ctx = LLVM.context(inst) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) - if GPUCompiler.deserves_retbox(typ) - rest = copy(rest) - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - end - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) - @assert legal - end - return UInt8(false) - end - - if (direction & API.UP) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args), ret) - @assert legal - end - if (direction & API.DOWN) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args)) - @assert legal - end - return UInt8(false) -end - -function inoutgcloaded_rule( - direction::Cint, - ret::API.CTypeTreeRef, - args::Ptr{API.CTypeTreeRef}, - known_values::Ptr{API.IntList}, - numArgs::Csize_t, - val::LLVM.API.LLVMValueRef, -)::UInt8 - if numArgs != 1 - return UInt8(false) - end - inst = LLVM.Instruction(val) - - legal, typ = abs_typeof(inst) - - if legal - if (direction & API.DOWN) != 0 - ctx = LLVM.context(inst) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - 
rest = typetree(typ, ctx, dl) - if GPUCompiler.deserves_retbox(typ) - rest = copy(rest) - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - end - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) - @assert legal - end - return UInt8(false) - end - - if (direction & API.UP) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret) - @assert legal - end - if (direction & API.DOWN) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2)) - @assert legal - end - return UInt8(false) -end diff --git a/src/typeanalysis.jl b/src/typeanalysis.jl index a84f96f856..a7bd3d1a77 100644 --- a/src/typeanalysis.jl +++ b/src/typeanalysis.jl @@ -9,15 +9,19 @@ LLVM.dispose(ta::TypeAnalysis) = API.FreeTypeAnalysis(ta) function TypeAnalysis( logic, - typerules::Dict{String,CustomRuleType} = Dict{String,CustomRuleType}(), + typerules::Union{Dict{String,CustomRuleType}, Nothing} = nothing, ) - rulenames = String[] - rules = CustomRuleType[] - for (rulename, rule) in typerules - push!(rulenames, rulename) - push!(rules, rule) + if typerules isa Nothing + ref = API.CreateTypeAnalysis(logic, (), ()) + else + rulenames = String[] + rules = CustomRuleType[] + for (rulename, rule) in typerules + push!(rulenames, rulename) + push!(rules, rule) + end + ref = API.CreateTypeAnalysis(logic, rulenames, rules) end - ref = API.CreateTypeAnalysis(logic, rulenames, rules) TypeAnalysis(ref) end From 8b6538148cc03417fb43d9a9bdec14bad8febefc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 01:13:33 -0500 Subject: [PATCH 456/495] Save julia types on sret (#2127) * Save julia types on sret * fix * lig --- src/absint.jl | 10 ++++++++++ src/compiler.jl | 48 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 951db001d0..c4866fe4da 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -305,6 +305,16 @@ function abs_typeof( end end + 
if isa(arg, LLVM.AllocaInst) || isa(arg, LLVM.CallInst) + if haskey(metadata(arg), "enzymejl_allocart") + mds = operands(metadata(arg)["enzymejl_allocart"])[1]::MDString + mds = Base.convert(String, mds) + ptr = reinterpret(Ptr{Cvoid}, parse(UInt, mds)) + RT = Base.unsafe_pointer_to_objref(ptr) + return (true, RT, GPUCompiler.MUT_REF) + end + end + if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" diff --git a/src/compiler.jl b/src/compiler.jl index cbfac114e3..ad9367fca3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1727,6 +1727,7 @@ end else "Unknown object of type" * " " * string(TT) end + @assert !illegal illegalVal = cur illegal = true return make_batched(ncur, prevbb) @@ -1770,6 +1771,7 @@ end end cur2 = if changed + @assert !illegal illegalVal = cur illegal = true # TODO replace with correct insertions/splats @@ -1942,8 +1944,10 @@ end return make_batched(ncur, prevbb) end - illegal = true - illegalVal = cur + if !illegal + illegal = true + illegalVal = cur + end return ncur end @@ -7070,10 +7074,48 @@ end ctx = LLVM.context(mod) for f in functions(mod), bb in blocks(f), inst in instructions(bb) fn = isa(inst, LLVM.CallInst) ? 
LLVM.called_operand(inst) : nothing + + if !API.HasFromStack(inst) && isa(inst, LLVM.AllocaInst) + + calluse = nothing + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) && operands(u)[1] == inst + + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) + hassret = false + llvmfn = LLVM.called_operand(u) + if llvmfn isa LLVM.Function + for attr in collect(parameter_attributes(llvmfn, 1)) + if kind(attr) == sretkind + hassret = true + break + end + end + end + if hassret + calluse = u + end + end + end + if calluse isa LLVM.CallInst + _, RT = enzyme_custom_extract_mi(calluse, false) + if RT !== nothing + llrt, sret, returnRoots = get_return_info(RT) + if !(sret isa Nothing) && !is_sret_union(RT) + metadata(inst)["enzymejl_allocart"] = MDNode(LLVM.Metadata[MDString(string(convert(UInt, unsafe_to_pointer(RT))))]) + end + end + end + end if !API.HasFromStack(inst) && ((isa(inst, LLVM.CallInst) && - (!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst)) + (!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst)) legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal From b78ec7b43c4d7f00e0e01002b9dd5147da887d0c Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 11:40:35 -0500 Subject: [PATCH 457/495] Update Project.toml (#2134) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 369e43f88a..b5d14f19ad 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.7" -Enzyme_jll = "0.0.165" +Enzyme_jll = "0.0.166" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 30b6b2d93d8ef1bdfb9f628e8c111d123cc4595e Mon Sep 17 00:00:00 2001 From: William 
Moses Date: Thu, 28 Nov 2024 11:55:22 -0500 Subject: [PATCH 458/495] Fanon2 (#2136) * Fewer anonymous funcs * minor cleanup * cleanup * fix * Aggressively noinfer * more type annotations --- src/absint.jl | 16 ++++---- src/compiler.jl | 75 ++++++++++++++++++++++--------------- src/compiler/interpreter.jl | 14 +++---- src/compiler/validation.jl | 20 +++++++--- src/jlrt.jl | 4 +- src/rules/allocrules.jl | 7 +--- src/utils.jl | 54 +++++++++++++++++++------- 7 files changed, 118 insertions(+), 72 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index c4866fe4da..4bc607ed7f 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -1,7 +1,7 @@ # Abstractly interpret julia from LLVM # Return (bool if could interpret, julia object interpreted to) -function absint(arg::LLVM.Value, partial::Bool = false) +function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bool,Any} if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end @@ -165,7 +165,7 @@ function absint(arg::LLVM.Value, partial::Bool = false) return (false, nothing) end -function actual_size(@nospecialize(typ2)) +function actual_size(@nospecialize(typ2))::Int @static if VERSION < v"1.11-" if typ2 <: Array return sizeof(Ptr{Cvoid}) + 2 + 2 + 4 + 2 * sizeof(Csize_t) + sizeof(Csize_t) @@ -184,10 +184,10 @@ function actual_size(@nospecialize(typ2)) end end -@inline function first_non_ghost(@nospecialize(typ2)) +@inline function first_non_ghost(@nospecialize(typ2))::Tuple{Int, Int} @static if VERSION < v"1.11-" if typ2 <: Array - return (1, typed_fieldtype(typ2, 1)) + return (1, 0) end end fc = fieldcount(typ2) @@ -204,7 +204,7 @@ end return (-1, 0) end -function should_recurse(@nospecialize(typ2), arg_t, byref, dl) +function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool sz = sizeof(dl, arg_t) if byref != GPUCompiler.BITS_VALUE if sz != sizeof(Int) @@ -228,7 
+228,7 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end -function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=false)::Tuple{LLVM.Value, Int} +function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool=true, inttoptr::Bool=false)::Tuple{LLVM.Value, Int} offset = 0 while true if isa(larg, LLVM.ConstantExpr) @@ -277,7 +277,7 @@ function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=fals end function abs_typeof( - arg::LLVM.Value, + @nospecialize(arg::LLVM.Value), partial::Bool = false, seenphis=Set{LLVM.PHIInst}() )::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}} if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) @@ -729,7 +729,7 @@ function abs_typeof( return (false, nothing, nothing) end -function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} +function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if isa(arg, ConstantExpr) ce = arg while isa(ce, ConstantExpr) diff --git a/src/compiler.jl b/src/compiler.jl index ad9367fca3..83a4efde4f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1125,7 +1125,11 @@ struct Return2 end function force_recompute!(mod::LLVM.Module) - for f in functions(mod), bb in blocks(f), inst in collect(instructions(bb)) + for f in functions(mod), bb in blocks(f) + iter = LLVM.API.LLVMGetFirstInstruction(bb) + while iter != C_NULL + inst = LLVM.Instruction(iter) + iter = LLVM.API.LLVMGetNextInstruction(iter) if isa(inst, LLVM.LoadInst) has_loaded = false for u in LLVM.uses(inst) @@ -1170,6 +1174,7 @@ function force_recompute!(mod::LLVM.Module) end end end + end end function permit_inlining!(f::LLVM.Function) @@ -3275,7 +3280,7 @@ end # Enzyme compiler step ## -function annotate!(mod, mode) +function annotate!(mod::LLVM.Module) inactive = LLVM.StringAttribute("enzyme_inactive", "") active = LLVM.StringAttribute("enzyme_active", "") no_escaping_alloc = 
LLVM.StringAttribute("enzyme_no_escaping_allocation") @@ -3891,7 +3896,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end -function enzyme_custom_extract_mi(orig::LLVM.Instruction, error::Bool = true) +function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool = true) operand = LLVM.called_operand(orig) if isa(operand, LLVM.Function) return enzyme_custom_extract_mi(operand::LLVM.Function, error) @@ -6144,7 +6149,7 @@ end using Random # returns arg, return -function no_type_setting(@nospecialize(specTypes); world = nothing) +function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing) # Even though the julia type here is ptr{int8}, the actual data can be something else if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) return (true, false) @@ -7037,7 +7042,7 @@ end end # annotate - annotate!(mod, mode) + annotate!(mod) for name in ("gpu_report_exception", "report_exception") if haskey(functions(mod), name) exc = functions(mod)[name] @@ -8012,9 +8017,6 @@ end ::Type{TapeType}, args::Vararg{Any,N}, ) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal} - - JuliaContext() do ctx - Base.@_inline_meta F = eltype(FA) is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk @@ -8263,6 +8265,10 @@ end i += 1 end + ts_ctx = JuliaContext() + ctx = context(ts_ctx) + activate(ctx) + (ir, fn, combinedReturn) = try if is_adjoint NT = Tuple{ActiveRetTypes...} @@ -8441,31 +8447,35 @@ end ir = string(mod) fn = LLVM.name(llvm_f) + (ir, fn, combinedReturn) + finally + deactivate(ctx) + dispose(ts_ctx) + end - @assert length(types) == length(ccexprs) + @assert length(types) == length(ccexprs) - if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) - return quote - Base.@_inline_meta - Base.llvmcall( - ($ir, $fn), - $combinedReturn, - Tuple{$PT,$(types...)}, - fptr, - $(ccexprs...), - ) - end - 
else - return quote - Base.@_inline_meta - Base.llvmcall( - ($ir, $fn), - $combinedReturn, - Tuple{$(types...)}, - $(ccexprs...), - ) - end + if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) + return quote + Base.@_inline_meta + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$PT,$(types...)}, + fptr, + $(ccexprs...), + ) + end + else + return quote + Base.@_inline_meta + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$(types...)}, + $(ccexprs...), + ) end end end @@ -9071,7 +9081,10 @@ include("compiler/reflection.jl") ) copysetfn = meta.entry blk = first(blocks(copysetfn)) - for inst in collect(instructions(blk)) + iter = LLVM.API.LLVMGetFirstInstruction(blk) + while iter != C_NULL + inst = LLVM.Instruction(iter) + iter = LLVM.API.LLVMGetNextInstruction(iter) if isa(inst, LLVM.FenceInst) eraseInst(blk, inst) end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 7648236ffb..e1204a709f 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -123,7 +123,7 @@ Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) = Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) -function is_alwaysinline_func(@nospecialize(TT)) +function is_alwaysinline_func(@nospecialize(TT))::Bool isa(TT, DataType) || return false @static if VERSION ≥ v"1.11-" if TT.parameters[1] == typeof(Core.memoryref) @@ -133,7 +133,7 @@ function is_alwaysinline_func(@nospecialize(TT)) return false end -function is_primitive_func(@nospecialize(TT)) +function is_primitive_func(@nospecialize(TT))::Bool isa(TT, DataType) || return false ft = TT.parameters[1] if ft == typeof(Enzyme.pmap) @@ -156,11 +156,11 @@ function is_primitive_func(@nospecialize(TT)) return false end -function isKWCallSignature(@nospecialize(TT)) +function isKWCallSignature(@nospecialize(TT))::Bool return TT <: 
Tuple{typeof(Core.kwcall),Any,Any,Vararg} end -function simplify_kw(@nospecialize specTypes) +function simplify_kw(@nospecialize(specTypes)) if isKWCallSignature(specTypes) return Base.tuple_type_tail(Base.tuple_type_tail(specTypes)) else @@ -742,15 +742,15 @@ end end end -@inline function array_or_number(@nospecialize(Ty)) +@inline function array_or_number(@nospecialize(Ty))::Bool return Ty <: AbstractArray || Ty <: Number end -@inline function isa_array_or_number(@nospecialize(x)) +@inline function isa_array_or_number(@nospecialize(x))::Bool return x isa AbstractArray || x isa Number end -@inline function num_or_eltype(@nospecialize(Ty)) +@inline function num_or_eltype(@nospecialize(Ty))::Type if Ty <: AbstractArray eltype(Ty) else diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 2a5c860f64..3e833324b6 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -20,7 +20,7 @@ function get_blas_symbols() return symbols end -function lookup_blas_symbol(name) +function lookup_blas_symbol(name::String) Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false) end end @@ -127,7 +127,7 @@ function __init__() end end -function memoize!(ptr, fn) +function memoize!(ptr::Ptr{Cvoid}, fn::String)::String fn = get(ptr_map, ptr, fn) if !haskey(ptr_map, ptr) ptr_map[ptr] = fn @@ -140,7 +140,7 @@ end import GPUCompiler: IRError, InvalidIRError -function restore_lookups(mod::LLVM.Module) +function restore_lookups(mod::LLVM.Module)::Nothing T_size_t = convert(LLVM.LLVMType, Int) for (v, k) in FFI.ptr_map if haskey(functions(mod), k) @@ -421,7 +421,11 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp calls = LLVM.CallInst[] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 mod = LLVM.parent(f) - for bb in blocks(f), inst in collect(instructions(bb)) + for bb in blocks(f) + iter = LLVM.API.LLVMGetFirstInstruction(bb) + while iter != C_NULL + inst = LLVM.Instruction(iter) + 
iter = LLVM.API.LLVMGetNextInstruction(iter) if isa(inst, LLVM.CallInst) push!(calls, inst) # remove illegal invariant.load and jtbaa_const invariants @@ -489,7 +493,11 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp newf, _ = get_function!(mod, fname, FT) else found = nothing - for lbb in blocks(initfn), linst in collect(instructions(lbb)) + for lbb in blocks(initfn) + liter = LLVM.API.LLVMGetFirstInstruction(lbb) + while liter != C_NULL + linst = LLVM.Instruction(liter) + liter = LLVM.API.LLVMGetNextInstruction(liter) if !isa(linst, LLVM.CallInst) continue end @@ -502,6 +510,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp break end end + end if found == nothing msg = sprint() do io::IO println( @@ -630,6 +639,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp end end end + end while length(calls) > 0 inst = pop!(calls) diff --git a/src/jlrt.jl b/src/jlrt.jl index 300fdc5515..59acd1d231 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -326,7 +326,7 @@ function load_if_mixed(oval::OT, val::VT) where {OT, VT} end end -function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value)) +function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value))::LLVM.Value world = enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(oval) if !legal @@ -374,7 +374,7 @@ function ref_if_mixed(val::VT) where {VT} end end -function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value)) +function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value world = enzyme_extract_world(LLVM.parent(position(B))) legal, TT, _ = abs_typeof(val) if !legal diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 7c611b6c85..7cd3c59b4c 100644 --- 
a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -1,7 +1,4 @@ - -function array_inner(::Type{<:Array{T}}) where {T} - return T -end +@inline LLT_ALIGN(x::Int, sz::Int) = (((x) + (sz) - 1) & ~((sz) - 1)) function array_shadow_handler( B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, @@ -52,8 +49,6 @@ function array_shadow_handler( isunion = typ isa Union - LLT_ALIGN(x, sz) = (((x) + (sz) - 1) & ~((sz) - 1)) - if !isunboxed elsz = sizeof(Ptr{Cvoid}) al = elsz diff --git a/src/utils.jl b/src/utils.jl index d5d0ed733a..8fc1ce4962 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,12 +5,40 @@ Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. Only use inside Enzyme.jl should be for Types. """ -@inline unsafe_to_pointer(val::Type{T}) where {T} = ccall( - Base.@cfunction(Base.identity, Ptr{Cvoid}, (Ptr{Cvoid},)), +@inline unsafe_to_pointer(@nospecialize(val::Type)) = @static if sizeof(Int) == sizeof(Int64) + Base.llvmcall(( +""" +declare nonnull {}* @julia.pointer_from_objref({} addrspace(11)*) + +define i64 @f({} addrspace(10)* %obj) readnone alwaysinline { + %c = addrspacecast {} addrspace(10)* %obj to {} addrspace(11)* + %r = call {}* @julia.pointer_from_objref({} addrspace(11)* %c) + %e = ptrtoint {}* %r to i64 + ret i64 %e +} +""", "f"), + Ptr{Cvoid}, + Tuple{Any}, + val, +) +else + Base.llvmcall(( +""" +declare nonnull {}* @julia.pointer_from_objref({} addrspace(11)*) + +define i32 @f({} addrspace(10)* %obj) readnone alwaysinline { + %c = addrspacecast {} addrspace(10)* %obj to {} addrspace(11)* + %r = call {}* @julia.pointer_from_objref({} addrspace(11)* %c) + %e = ptrtoint {}* %r to i32 + ret i32 %e +} +""", "f"), Ptr{Cvoid}, - (Any,), + Tuple{Any}, val, ) +end + export unsafe_to_pointer @inline is_concrete_tuple(x::Type{T2}) where {T2} = @@ -53,7 +81,7 @@ end export unsafe_to_ptr # This mimicks literal_pointer_val / literal_pointer_val_slot -function unsafe_to_llvm(B::LLVM.IRBuilder, 
@nospecialize(val)) +function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val))::LLVM.Value T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) @@ -113,7 +141,7 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) end export unsafe_to_llvm, unsafe_nothing_to_llvm -function makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T)) +function makeInstanceOf(B::LLVM.IRBuilder, @nospecialize(T::Type)) if !Core.Compiler.isconstType(T) throw(AssertionError("Tried to make instance of non constant type $T")) end @@ -123,7 +151,7 @@ end export makeInstanceOf -function hasfieldcount(@nospecialize(dt)) +function hasfieldcount(@nospecialize(dt))::Bool try fieldcount(dt) catch @@ -240,7 +268,7 @@ export my_methodinstance # # // followed by alignment padding and inline data, or owner pointer # } jl_array_t; -@inline function typed_fieldtype(@nospecialize(T::Type), i::Int) +@inline function typed_fieldtype(@nospecialize(T::Type), i::Int)::Type if T <: Array eT = eltype(T) PT = Ptr{eT} @@ -250,7 +278,7 @@ export my_methodinstance end end -@inline function typed_fieldcount(@nospecialize(T::Type)) +@inline function typed_fieldcount(@nospecialize(T::Type))::Int if T <: Array return 7 else @@ -258,7 +286,7 @@ end end end -@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int) +@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int)::Int if T <: Array tys = (Ptr, Csize_t, UInt16, UInt16, UInt32, Csize_t, Csize_t) sum = 0 @@ -275,7 +303,7 @@ end else -@inline function typed_fieldtype(@nospecialize(T::Type), i::Int) +@inline function typed_fieldtype(@nospecialize(T::Type), i::Int)::Type if T <: GenericMemoryRef && i == 1 || T <: GenericMemory && i == 2 eT = eltype(T) Ptr{eT} @@ -284,11 +312,11 @@ else end end -@inline function typed_fieldcount(@nospecialize(T::Type)) +@inline function typed_fieldcount(@nospecialize(T::Type))::Int fieldcount(T) end -@inline function 
typed_fieldoffset(@nospecialize(T::Type), i::Int) +@inline function typed_fieldoffset(@nospecialize(T::Type), i::Int)::Int fieldoffset(T, i) end @@ -299,7 +327,7 @@ export typed_fieldcount export typed_fieldoffset # returns the inner type of an sret/enzyme_sret/enzyme_sret_v -function sret_ty(fn::LLVM.Function, idx::Int) +function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType return eltype(LLVM.value_type(LLVM.parameters(fn)[idx])) end From 6746e5a0040427dcecf470eadf704bff025f903f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 16:06:58 -0500 Subject: [PATCH 459/495] Further absint improvements (#2140) * Further absint improvements * Update Project.toml --------- Co-authored-by: William Moses --- src/absint.jl | 20 ++++++++++++++------ src/compiler.jl | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 4bc607ed7f..5a72fa0873 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -576,12 +576,20 @@ function abs_typeof( end end if !seen && typed_fieldcount(typ) > 0 - offset = offset - typed_fieldoffset(typ, lasti) - typ = typed_fieldtype(typ, lasti) - @assert Base.isconcretetype(typ) - if !Base.allocatedinline(typ) - legal = false - end + offset = offset - typed_fieldoffset(typ, lasti) + typ = typed_fieldtype(typ, lasti) + if offset == 0 + if !Base.allocatedinline(typ) + if byref != GPUCompiler.BITS_VALUE + legal = false + end + byref = GPUCompiler.MUT_REF + end + else + if !Base.isconcretetype(typ) || !Base.allocatedinline(typ) + legal = false + end + end seen = true end if !seen diff --git a/src/compiler.jl b/src/compiler.jl index 83a4efde4f..c94a6c84b4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6177,7 +6177,7 @@ function GPUCompiler.codegen( ) params = job.config.params if params.run_enzyme - @assert eltype(params.rt) != Union{} + # @assert eltype(params.rt) != Union{} end expectedTapeType = params.expectedTapeType mode = params.mode From a207b27e8ec57d7eba98ba9b77fed5f9120932a7 
Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Nov 2024 16:07:13 -0500 Subject: [PATCH 460/495] Speed up invoke (#2138) * Speed up invoke * handler specialize --- src/compiler/interpreter.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index e1204a709f..1e442482be 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -759,13 +759,13 @@ end end function abstract_call_known( - @nospecialize(interp::EnzymeInterpreter), + interp::EnzymeInterpreter{Handler}, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int = get_max_methods(interp, f, sv), -) +) where Handler (; fargs, argtypes) = arginfo @@ -889,7 +889,7 @@ function abstract_call_known( end return Base.@invoke abstract_call_known( interp::AbstractInterpreter, - f, + f::Any, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, From 2bfc9b512c61a93d5f277a58d412dcf0b3c50eb3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 01:05:36 -0500 Subject: [PATCH 461/495] Organize code out of compiler.jl (#2137) * Organize code out of compiler.jl * More cleanup * more cleaning * add file * Add file * fix * Update sugar.jl * Update sugar.jl * Update sugar.jl * fixup * fix * Update sugar.jl * Update sugar.jl * Update sugar.jl --- src/Enzyme.jl | 1036 +-------- src/analyses/activity.jl | 457 ++++ src/{typeanalysis.jl => analyses/type.jl} | 0 src/compiler.jl | 2567 +-------------------- src/errors.jl | 640 +++++ src/llvm/attributes.jl | 1060 +++++++++ src/sugar.jl | 1155 +++++++++ src/{ => typeutils}/make_zero.jl | 0 src/typeutils/recursive_add.jl | 86 + 9 files changed, 3507 insertions(+), 3494 deletions(-) create mode 100644 src/analyses/activity.jl rename src/{typeanalysis.jl => analyses/type.jl} (100%) create mode 100644 src/errors.jl create mode 100644 src/llvm/attributes.jl create mode 100644 src/sugar.jl rename src/{ => typeutils}/make_zero.jl (100%) create mode 
100644 src/typeutils/recursive_add.jl diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 8c94911368..942df0581c 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -120,7 +120,7 @@ Base.convert(::Type{API.CDerivativeMode}, ::ForwardMode) = API.DEM_ForwardMode function guess_activity end include("logic.jl") -include("typeanalysis.jl") +include("analyses/type.jl") include("typetree.jl") include("gradientutils.jl") include("utils.jl") @@ -1264,7 +1264,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType false, #=errifwritte=# RuntimeActivity, ) - job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel = false)) + job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params; kernel = false)) key = hash(parent_job, hash(job)) @@ -1515,1037 +1515,7 @@ end nothing end -function zerosetfn(x, i::Int) - res = zero(x) - @inbounds res[i] = 1 - return res -end - -@inline function onehot(x::Array) - Compiler.onehot_internal(zerosetfn, x, 0, length(x)) -end - -@inline function onehot(x::Array, start::Int, endl::Int) - Compiler.onehot_internal(zerosetfn, x, start-1, endl-start+1) -end - -@inline function onehot(x::AbstractArray) - N = length(x) - ntuple(Val(N)) do i - Base.@_inline_meta - res = similar(x) - for idx = 1:N - @inbounds res[idx] = (i == idx) ? 1.0 : 0.0 - end - return res - end -end -@inline function onehot(x::AbstractArray, start::Int, endl::Int) - ntuple(Val(endl - start + 1)) do i - Base.@_inline_meta - res = similar(x) - for idx = 1:length(x) - @inbounds res[idx] = (i + start - 1 == idx) ? 1.0 : 0.0 - end - return res - end -end - -@inline function onehot(::Type{NTuple{N,T}}) where {T,N} - ntuple(Val(N)) do i - Base.@_inline_meta - ntuple(Val(N)) do idx - Base.@_inline_meta - return (i == idx) ? 
T(1) : T(0) - end - end -end -@inline onehot(x::Tuple{}) = () -@inline function onehot(x::NTuple{N,T}) where {T,N} - onehot(NTuple{N,T}) -end -@inline function onehot(x::NTuple{N,T}, start, endl) where {T,N} - ntuple(Val(endl - start + 1)) do i - Base.@_inline_meta - ntuple(Val(N)) do idx - Base.@_inline_meta - return (i + start - 1 == idx) ? T(1) : T(0) - end - end -end - -@inline function onehot(x::AbstractFloat) - return (one(x),) -end - -""" - gradient(::ReverseMode, f, args...) - -Compute the gradient of a real-valued function `f` using reverse mode. -For each differentiable argument, this function will allocate and return new derivative object, returning -a tuple of derivatives for each argument. If an argument is not differentiable, the element of the returned -tuple with be nothing. - -In reverse mode (here), the derivatives will be the same type as the original argument. - -This is a structure gradient. For a struct `x` it returns another instance of the same type, -whose fields contain the components of the gradient. -In the result, `grad.a` contains `∂f/∂x.a` for any differential `x.a`, -while `grad.c == x.c` for other types. - -Examples: - -```jldoctest gradient -f(x) = x[1]*x[2] - -grad = gradient(Reverse, f, [2.0, 3.0]) - -# output -([3.0, 2.0],) -``` - -```jldoctest gradient -grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) - -# output - -((a = 3.0, b = [2.0], c = "str"),) -``` - -```jldoctest gradient -mul(x, y) = x[1]*y[1] - -grad = gradient(Reverse, mul, [2.0], [3.0]) - -# output -([3.0], [2.0]) -``` - -```jldoctest gradient - -grad = gradient(Reverse, mul, [2.0], Const([3.0])) - -# output -([3.0], nothing) -``` - -If passing a mode that returns the primal (e.g. ReverseWithPrimal), the return type will instead be -a tuple where the first element contains the derivatives, and the second element contains the result of the original computation. 
- -```jldoctest gradient - -grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) - -# output -(derivs = ([3.0, 2.0],), val = 6.0) -``` -```jldoctest gradient - -grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0]) - -# output -(derivs = ([3.0], [2.0]), val = 6.0) -``` - -```jldoctest gradient -grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) - -# output -(derivs = ([3.0], nothing), val = 6.0) -``` - -""" -@generated function gradient( - rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, - f::F, - x::ty_0, - args::Vararg{Any,N}, -) where {F,ty_0,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten,N} - toemit = Expr[quote - act_0 = - !(x isa Enzyme.Const) && - Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == - Compiler.ActiveState #=justActive=# - end] - rargs = Union{Symbol,Expr}[:x] - acts = Symbol[Symbol("act_0")] - - for i = 1:N - argidx = quote - args[$i] - end - push!(rargs, argidx) - sym = Symbol("act_$i") - push!(acts, sym) - push!( - toemit, - quote - $sym = - !($argidx isa Enzyme.Const) && - Compiler.active_reg_inner( - Core.Typeof($argidx), - (), - nothing, - Val(true), - ) == Compiler.ActiveState #=justActive=# - end, - ) - end - - idx = 0 - shadows = Symbol[] - enz_args = Expr[] - resargs = Expr[] - for (arg, act) in zip(rargs, acts) - shad = Symbol("shad_$idx") - push!(shadows, shad) - push!(toemit, quote - $shad = if $arg isa Enzyme.Const - nothing - elseif $act - Ref(make_zero($arg)) - else - make_zero($arg) - end - end) - push!(enz_args, quote - if $arg isa Enzyme.Const - $arg - elseif $act - MixedDuplicated($arg, $shad) - else - Duplicated($arg, $shad) - end - end) - push!(resargs, quote - if $arg isa Enzyme.Const - nothing - elseif $act - $shad[] - else - $shad - end - end) - idx += 1 - end - push!(toemit, quote - res = autodiff(rm, f, Active, $(enz_args...)) - end) - - if ReturnPrimal - return quote - Base.@_inline_meta - $(toemit...) 
- (; derivs = ($(resargs...),), val = res[2]) - end - else - return quote - Base.@_inline_meta - $(toemit...) - ($(resargs...),) - end - end -end - -""" - gradient!(::ReverseMode, dx, f, x) - -Compute the gradient of an array-input function `f` using reverse mode, -storing the derivative result in an existing array `dx`. -Both `x` and `dx` must be `Array`s of the same type. - -Example: - -```jldoctest gradip -f(x) = x[1]*x[2] - -dx = [0.0, 0.0] -gradient!(Reverse, dx, f, [2.0, 3.0]) - -# output -([3.0, 2.0],) -``` - -```jldoctest gradip -dx = [0.0, 0.0] -gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) - -# output -(derivs = ([3.0, 2.0],), val = 6.0) -``` -""" -@inline function gradient!( - rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, - dx::X, - f::F, - x::X, -) where {X<:Array,F,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} - make_zero!(dx) - res = autodiff(rm, f, Active, Duplicated(x, dx)) - return if ReturnPrimal - (; derivs = (dx,), val = res[2]) - else - (dx,) - end -end - -@inline function chunkedonehot(x, ::Val{chunk}) where {chunk} - sz = length(x) - num = ((sz + chunk - 1) ÷ chunk) - ntuple(Val(num)) do i - Base.@_inline_meta - onehot(x, (i - 1) * chunk + 1, i == num ? sz : (i * chunk)) - end -end - -@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where {chunk} - return ((one(x),),) -end - -@inline tupleconcat(x) = x -@inline tupleconcat(x, y) = (x..., y...) -@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) 
- -@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N} - args = Union{Symbol,Expr}[:x] - tys = Type[X] - for i in 1:N - push!(args, :(vargs[$i])) - push!(tys, vargs[i]) - end - - exprs = Union{Symbol,Expr}[] - for (arg, ty) in zip(args, tys) - if ty <: Enzyme.Const - push!(exprs, :(nothing)) - elseif ty <: AbstractFloat - push!(exprs, :(nothing)) - elseif ChunkTy == Nothing || ChunkTy == Val{1} - push!(exprs, :(onehot($arg))) - else - push!(exprs, :(chunkedonehot($arg, chunk))) - end - end - return quote - Base.@_inline_meta - ($(exprs...),) - end -end - -struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N} - data::NTuple{Length,T} -end -TupleArray(data::NTuple{Length,T}, Shape) where {Length,T} = - TupleArray{T,Shape,Length,length(Shape)}(data) - -@inline Base.eltype(::TupleArray{T}) where {T} = T -@inline Base.eltype(::Type{<:TupleArray{T}}) where {T} = T -@inline Base.size(::TupleArray{<:Any,Shape}) where {Shape} = Shape -@inline Base.ndims(::TupleArray{<:Any,<:Any,<:Any,N}) where {N} = N - -function Base.convert( - ::Type{Array{T,N}}, - X::TupleArray{T,Shape,Length,N}, -) where {T,Shape,Length,N} - vals = Array{T,N}(undef, Shape...) - for i = 1:Length - @inbounds val[i] = X.data[i] - end - return vals -end - -function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} - start = 0 - for i = 1:N - start *= size(a, N - i + 1) - start += (args[N-i+1] - 1) - end - start += 1 - return a.data[start] -end - -@inline function tupstack(data::Tuple{Vararg{Array{T}}}, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) where {T} - num = prod(outshape) - res = Array{T}(undef, outshape..., inshape...) 
- for (i, val) in enumerate(data) - Base.unsafe_copyto!(res, num*(i-1)+1, val, 1, Base.reinterpret(UInt, num)) - end - res -end - -@inline function tupstack(x, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) - st = Base.stack(x) - if length(outshape) == 1 - st - else - reshape(st, (outshape..., inshape...)) - end -end - -@inline specialize_output(output, input) = output - -""" - gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) - -Compute the gradient of an array-input function `f` using forward mode. The -optional keyword argument `shadow` is a vector of one-hot vectors of type `x` -which are used to forward-propagate into the return. For performance reasons, -this should be computed once, outside the call to `gradient`, rather than -within this call. - -Example: - -```jldoctest gradfwd -f(x) = x[1]*x[2] - -gradient(Forward, f, [2.0, 3.0]) - -# output - -([3.0, 2.0],) -``` - -```jldoctest gradfwd -gradient(ForwardWithPrimal, f, [2.0, 3.0]) - -# output -(derivs = ([3.0, 2.0],), val = 6.0) -``` - -```jldoctest gradfwd -gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) - -# output - -([3.0, 2.0],) -``` - -```jldoctest gradfwd -gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) - -# output -(derivs = ([3.0, 2.0],), val = 6.0) -``` - -For functions which return an AbstractArray or scalar, this function will return an AbstractArray -whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made -about the type of the AbstractArray returned by this function (which may or may not be the same -as the input AbstractArray if provided). - -For functions who return other types, this function will retun an AbstractArray -of shape `size(input)` of values of the output type. 
-```jldoctest -f(x) = [ x[1] * x[2], x[2] + x[3] ] - -grad = gradient(Forward, f, [2.0, 3.0, 4.0]) - -# output -([3.0 2.0 0.0; 0.0 1.0 1.0],) -``` - -This function supports multiple arguments and computes the gradient with respect to each - -```jldoctest gradfwd2 -mul(x, y) = x[1]*y[2] + x[2]*y[1] - -gradient(Forward, mul, [2.0, 3.0], [2.7, 3.1]) - -# output - -([3.1, 2.7], [3.0, 2.0]) -``` - -This includes the ability to mark some arguments as `Const` if its derivative is not needed, returning nothing in the corresponding derivative map. - -```jldoctest gradfwd2 -gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) - -# output - -([3.1, 2.7], nothing) -``` -""" -@generated function gradient( - fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, - f::F, - x::ty_0, - args::Vararg{Any,N}; - chunk::CS = nothing, - shadows::ST = create_shadows(chunk, x, args...), -) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS,ST, ty_0, N} - - syms = Union{Symbol,Expr}[:x] - shads = Union{Symbol,Expr}[:(shadows[1])] - tys = Type[ty_0] - for i in 1:N - push!(syms, :(args[$i])) - push!(tys, args[i]) - push!(shads, :(shadows[1+$i])) - end - fval = if F <: Annotation - :(f.val) - else - :f - end - - vals = Union{Symbol,Expr}[] - consts = Union{Symbol,Expr}[] - for (arg, ty) in zip(syms, tys) - if ty <: Const - push!(vals, :($arg.val)) - push!(consts, arg) - else - push!(vals, arg) - push!(consts, :(Const($arg))) - end - end - - if CS == Val{0} - return quote - Base.@_inline_meta - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - end - - exprs = Union{Symbol,Expr}[] - primal = nothing - derivatives = Union{Symbol,Expr}[] - - primmode = :(fm) - for (i, (arg, ty)) in enumerate(zip(syms, tys)) - if ty <: Const - push!(derivatives, :(nothing)) - continue - end - - argnum = length(ST.parameters[i].parameters) - - argderivative = if ty <: AbstractFloat - dargs = Union{Symbol,Expr}[] - for (j, arg2) in enumerate(syms) - if i == j - 
push!(dargs, :(Duplicated($arg, one($arg)))) - else - push!(dargs, consts[j]) - end - end - - resp = Symbol("resp_$i") - push!(exprs, quote - $resp = autodiff($primmode, f, Duplicated, $(dargs...)) - end) - if ReturnPrimal && primal == nothing - primal = :($resp[2]) - primmode = NoPrimal(fm()) - end - - :($resp[1]) - elseif argnum == 0 - vals[i] - elseif CS == Nothing - dargs = Union{Symbol,Expr}[] - for (j, arg2) in enumerate(syms) - if i == j - push!(dargs, :(BatchDuplicated($arg, $(shads[i])))) - else - push!(dargs, consts[j]) - end - end - - df = :f - if F <: Enzyme.Duplicated - zeros = Expr[] - for i in 1:argnum - push!(zeros, :(f.dval)) - end - df = :(BatchDuplicated(f.val, ($(zeros...),) )) - end - - resp = Symbol("resp_$i") - push!(exprs, quote - $resp = autodiff($primmode, $df, BatchDuplicated, $(dargs...)) - end) - if ReturnPrimal && primal == nothing - primal = :($resp[2]) - primmode = NoPrimal(fm()) - end - - :(values($resp[1])) - elseif CS == Val{1} - subderivatives = Union{Symbol,Expr}[] - for an in 1:argnum - dargs = Union{Symbol,Expr}[] - for (j, arg2) in enumerate(syms) - if i == j - push!(dargs, :(Duplicated($arg, $(shads[i])[$an]))) - else - push!(dargs, consts[j]) - end - end - - resp = Symbol("resp_$i"*"_"*string(an)) - push!(exprs, quote - $resp = autodiff($primmode, f, Duplicated, $(dargs...)) - end) - if ReturnPrimal && primal == nothing - primal = :($resp[2]) - primmode = NoPrimal(fm()) - end - - push!(subderivatives, :(values($resp[1]))) - end - :(($(subderivatives...),)) - else - subderivatives = Union{Symbol,Expr}[] - for an in 1:argnum - dargs = Union{Symbol,Expr}[] - for (j, arg2) in enumerate(syms) - if i == j - push!(dargs, :(BatchDuplicated($arg, $(shads[i])[$an]))) - else - push!(dargs, consts[j]) - end - end - - resp = Symbol("resp_$i"*"_"*string(an)) - push!(exprs, quote - $resp = autodiff($primmode, f, BatchDuplicated, $(dargs...)) - end) - if ReturnPrimal && primal == nothing - primal = :($resp[2]) - primmode = NoPrimal(fm()) - 
end - - push!(subderivatives, :(values($resp[1]))) - end - :(tupleconcat($(subderivatives...))) - end - - deriv = if ty <: AbstractFloat - argderivative - else - tmp = Symbol("tmp_$i") - push!(exprs, :($tmp = $argderivative)) - if ty <: AbstractArray - if argnum > 0 - quote - if $tmp[1] isa AbstractArray - inshape = size($(vals[1])) - outshape = size($tmp[1]) - # st : outshape x total inputs - tupstack($tmp, outshape, inshape) - else - specialize_output(TupleArray($tmp, size($arg)), $(vals[1])) - end - end - else - tmp - end - else - tmp - end - end - push!(derivatives, deriv) - end - - # We weirdly asked for no derivatives - if ReturnPrimal && primal == nothing - primal = :($fval($(vals...))) - end - - result = if ReturnPrimal - :((; derivs = ($(derivatives...),), val = $primal)) - else - :(($(derivatives...),)) - end - - return quote - Base.@_inline_meta - $(exprs...) - $result - end -end - -""" - jacobian(::ForwardMode, args...; kwargs...) - -Equivalent to gradient(::ForwardMode, args...; kwargs...) -""" -@inline function jacobian(fm::ForwardMode, args...; kwargs...) - gradient(fm, args...; kwargs...) -end - -""" - jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=nothing) - jacobian(::ReverseMode, f, x) - -Compute the jacobian of a array-output function `f` using (potentially vector) -reverse mode. The `chunk` argument optionally denotes the chunk size to use and -`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`). 
- -Example: - -```jldoctest -f(x) = [ x[1] * x[2], x[2] + x[3] ] - -jacobian(Reverse, f, [2.0, 3.0, 4.0]) - -# output -([3.0 2.0 0.0; 0.0 1.0 1.0],) -``` - -```jldoctest -f(x) = [ x[1] * x[2], x[2] + x[3] ] - -grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0]) - -# output -(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) -``` - -```jldoctest -f(x) = [ x[1] * x[2], x[2] + x[3] ] - -grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) - -# output -([3.0 2.0 0.0; 0.0 1.0 1.0],) -``` - -```jldoctest -f(x) = [ x[1] * x[2], x[2] + x[3] ] - -grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) - -# output -(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) -``` - -This function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. -No guarantees are presently made about the type of the AbstractArray returned by this function -(which may or may not be the same as the input AbstractArray if provided). - -In the future, when this function is extended to handle non-array return types, -this function will retun an AbstractArray of shape `size(output)` of values of the input type. 
-``` -""" -@inline function jacobian( - mode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, - f::F, - x::X; - n_outs::OutType = nothing, - chunk::CT = nothing, -) where {ReturnPrimal,F,X,RABI<:ABI,ErrIfFuncWritten,RuntimeActivity,OutType,CT,Holomorphic} - - if n_outs == nothing - res = if f isa Const - f.val(x) - else - f(x) - end - jac = if res isa AbstractArray - jacobian( - ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), - f, - x; - n_outs = Val(size(res)), - chunk, - ) - elseif res isa AbstractFloat - gradient( - ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), - f, - x, - ) - else - throw( - AssertionError( - "Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))", - ), - ) - end - - return if ReturnPrimal - (; derivs = jac, val = res) - else - jac - end - else - n_out_val = if length(Compiler.element(n_outs)) == 0 - 0 - else - prod(Compiler.element(n_outs)) - end - - if chunk == Val(0) - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - - XT = Core.Typeof(x) - MD = Compiler.active_reg_inner(XT, (), nothing, Val(true)) == Compiler.ActiveState #=justActive=# - tt = Tuple{XT} - FRT = if f isa Const - Core.Typeof(f.val) - else - Core.Typeof(f) - end - - rt = Compiler.primal_return_type(mode, FRT, tt) - - ModifiedBetweenT = (false, false) - FA = Const{FRT} - - if chunk == Val(1) || chunk == nothing - primal, adjoint = autodiff_thunk( - ReverseModeSplit{ - #=ReturnPrimal=#false, - #=ReturnShadow=#true, - RuntimeActivity, - #=width=#1, - ModifiedBetweenT, - RABI, - Holomorphic, - ErrIfFuncWritten, - #=ShadowInit=#false - }(), - FA, - DuplicatedNoNeed{rt}, - MD ? MixedDuplicated{XT} : Duplicated{XT} - ) - tmp = ntuple(Val(n_out_val)) do i - Base.@_inline_meta - z = make_zero(x) - dx = MD ? Ref(z) : z - res = primal(Const(f), MD ? 
MixedDuplicated(x, dx) : Duplicated(x, dx)) - tape = res[1] - @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) - adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) - return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) - end - rows = map(first, tmp) - outshape = tmp[1][2] - rows, outshape - else - chunksize = Compiler.element(chunk) - primal, adjoint = autodiff_thunk( - ReverseModeSplit{ - #=ReturnPrimal=#false, - #=ReturnShadow=#true, - RuntimeActivity, - chunksize, - ModifiedBetweenT, - RABI, - Holomorphic, - ErrIfFuncWritten, - #=ShadowInit=#false - }(), - FA, - BatchDuplicatedNoNeed{rt, chunksize}, - MD ? BatchMixedDuplicated{XT, chunksize} : BatchDuplicated{XT, chunksize} - ) - - num = ((n_out_val + chunksize - 1) ÷ chunksize) - - if num * chunksize == n_out_val - last_size = chunksize - primal2, adjoint2 = primal, adjoint - else - last_size = n_out_val - (num - 1) * chunksize - tt′ = Tuple{BatchDuplicated{Core.Typeof(x),last_size}} - primal2, adjoint2 = autodiff_thunk( - ReverseModeSplit{ - #=ReturnPrimal=#false, - #=ReturnShadow=#true, - RuntimeActivity, - last_size, - ModifiedBetweenT, - RABI, - Holomorphic, - ErrIfFuncWritten, - #=ShadowInit=#false - }(), - FA, - BatchDuplicatedNoNeed{rt, last_size}, - MD ? BatchMixedDuplicated{XT, last_size} : BatchDuplicated{XT, last_size} - ) - end - - tmp = ntuple(num) do i - Base.@_inline_meta - dx = ntuple(Val(i == num ? last_size : chunksize)) do idx - Base.@_inline_meta - z = make_zero(x) - MD ? Ref(z) : z - end - res = (i == num ? primal2 : primal)( - Const(f), - MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), - ) - tape = res[1] - j = 0 - for shadow in res[3] - j += 1 - @inbounds shadow[(i-1)*chunksize+j] += - Compiler.default_adjoint(eltype(typeof(shadow))) - end - (i == num ? adjoint2 : adjoint)( - Const(f), - MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), - tape, - ) - return MD ? ( - ntuple(Val(i == num ? 
last_size : chunksize)) do idx - Base.@_inline_meta - dx[idx][] - end - ) : dx, - (i == 1 ? size(res[3][1]) : nothing) - end - rows = tupleconcat(map(first, tmp)...) - outshape = tmp[1][2] - rows, outshape - end - res = if x isa AbstractArray - inshape = size(x) - st2 = tupstack(rows, inshape, outshape) - - st3 = if length(outshape) == 1 && length(inshape) == 1 - transpose(st2) - else - transp = ( - ((length(inshape)+1):(length(inshape)+length(outshape)))..., - (1:length(inshape))..., - ) - PermutedDimsArray(st2, transp) - end - - st3 - else - reshape(collect(rows), outshape) - end - if ReturnPrimal - # TODO optimize away redundant fwd pass - (; derivs = (res,), val = if f isa Enzyme.Const - f.val(x) - else - f(x) - end) - else - (res,) - end - end -end - -""" - hvp(f::F, x::X, v::X) where {F, X} - -Compute the Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v`. - -In other words, compute hessian(f)(x) * v - -See [`hvp!`](@ref) for a version which stores the result in an existing buffer and also [`hvp_and_gradient!`](@ref) for a function to compute both the hvp and the gradient in a single call. - -Example: - -```jldoctest hvp; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" -f(x) = sin(x[1] * x[2]) - -hvp(f, [2.0, 3.0], [5.0, 2.7]) - -# output -2-element Vector{Float64}: - 19.6926882637302 - 16.201003759768003 -``` -""" -@inline function hvp(f::F, x::X, v::X) where {F,X} - res = make_zero(x) - hvp!(res, f, x, v) - return res -end - - -""" - hvp!(res::X, f::F, x::X, v::X) where {F, X} - -Compute an in-place Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v`. -The result will be stored into `res`. The function still allocates and zero's a buffer to store the intermediate gradient, which is -not returned to the user. 
- -In other words, compute res .= hessian(f)(x) * v - -See [`hvp_and_gradient!`](@ref) for a function to compute both the hvp and the gradient in a single call. - -Example: - -```jldoctest hvpip; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" -f(x) = sin(x[1] * x[2]) - -res = Vector{Float64}(undef, 2) -hvp!(res, f, [2.0, 3.0], [5.0, 2.7]) - -res -# output -2-element Vector{Float64}: - 19.6926882637302 - 16.201003759768003 -``` -""" -@inline function hvp!(res::X, f::F, x::X, v::X) where {F,X} - grad = make_zero(x) - Enzyme.autodiff( - Forward, - gradient!, - Const(Reverse), - DuplicatedNoNeed(grad, res), - Const(f), - Duplicated(x, v), - ) - return nothing -end - - - -""" - hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} - -Compute an in-place Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v` as well as -the gradient, storing the gradient into `grad`. Both the hessian vector product and the gradient can be computed together more efficiently -than computing them separately. - -The result will be stored into `res`. The gradient will be stored into `grad`. 
- -In other words, compute res .= hessian(f)(x) * v and grad .= gradient(Reverse, f)(x) - -Example: - -```jldoctest hvp_and_gradient; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" -f(x) = sin(x[1] * x[2]) - -res = Vector{Float64}(undef, 2) -grad = Vector{Float64}(undef, 2) -hvp_and_gradient!(res, grad, f, [2.0, 3.0], [5.0, 2.7]) - -res -grad -# output -2-element Vector{Float64}: - 2.880510859951098 - 1.920340573300732 -``` -""" -@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F,X} - Enzyme.autodiff( - Forward, - gradient!, - Const(Reverse), - Duplicated(grad, res), - Const(f), - Duplicated(x, v), - ) - return nothing -end - +include("sugar.jl") function _import_frule end # defined in EnzymeChainRulesCoreExt extension diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl new file mode 100644 index 0000000000..f3dcd3a877 --- /dev/null +++ b/src/analyses/activity.jl @@ -0,0 +1,457 @@ +@enum ActivityState begin + AnyState = 0 + ActiveState = 1 + DupState = 2 + MixedState = 3 +end + +@inline function Base.:|(a1::ActivityState, a2::ActivityState) + ActivityState(Int(a1) | Int(a2)) +end + +struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed} + world::worldT +end + +@inline element(::Val{T}) where {T} = T + +@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})( + f::Int, +) where {seen,worldT,justActive,UnionSret,AbstractIsMixed} + T = element(first(seen)) + + reftype = ismutabletype(T) || (T isa UnionAll && !AbstractIsMixed) + + if justActive && reftype + return Val(AnyState) + end + + subT = typed_fieldtype(T, f) + + if justActive && ismutabletype(subT) + return Val(AnyState) + end + + sub = active_reg_inner( + subT, + seen, + c.world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) + + if sub == AnyState + Val(AnyState) + else + if sub == DupState + if justActive + Val(AnyState) + else + Val(DupState) + end + else + if reftype + Val(DupState) + else + Val(sub) + end + end 
+ end +end + +@inline forcefold(::Val{RT}) where {RT} = RT + +@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N}) where {ty,sty,N} + if sty == AnyState || sty == ty + return forcefold(Val(ty), C...) + end + if ty == AnyState + return forcefold(Val(sty), C...) + else + return MixedState + end +end + +@inline numbereltype(::Type{<:EnzymeCore.RNumber{T}}) where {T} = T +@inline ptreltype(::Type{<:EnzymeCore.RArray{T}}) where {T} = T +@inline ptreltype(::Type{Ptr{T}}) where {T} = T +@inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T +@inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T +@inline ptreltype(::Type{Base.RefValue{T}}) where {T} = T +@inline ptreltype(::Type{Array{T,N}}) where {T,N} = T +@inline ptreltype(::Type{Array{T,N} where N}) where {T} = T +@inline ptreltype(::Type{Complex{T}}) where {T} = T +@inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T +@inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V +@inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V +@inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T +@static if VERSION < v"1.11-" +else +@inline ptreltype(::Type{Memory{T}}) where T = T +end + +@inline is_arrayorvararg_ty(::Type) = false +@inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true +@inline is_arrayorvararg_ty(::Type{Array{T,N} where N}) where {T} = true +@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where {T2} = true +@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where {T} = true +@inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N}}) where {T,N} = true +@inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N} where N}) where {T} = true +@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true +@inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true +@inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true +@inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true +@static if VERSION < 
v"1.11-" +else +@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true +end + +@inline function datatype_fieldcount(t::Type{T}) where {T} + return Base.datatype_fieldcount(t) +end + +@inline function staticInTup(::Val{T}, tup::NTuple{N,Val}) where {T,N} + any(ntuple(Val(N)) do i + Base.@_inline_meta + Val(T) == tup[i] + end) +end + +@inline function active_reg_recur( + ::Type{ST}, + seen::Seen, + world, + ::Val{justActive}, + ::Val{UnionSret}, + ::Val{AbstractIsMixed}, +) where {ST,Seen,justActive,UnionSret,AbstractIsMixed} + if ST isa Union + return forcefold( + Val( + active_reg_recur( + ST.a, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ), + ), + Val( + active_reg_recur( + ST.b, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ), + ), + ) + end + return active_reg_inner( + ST, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) +end + +@inline is_vararg_tup(x) = false +@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where {T2} = true + +@inline function active_reg_inner( + ::Type{T}, + seen::ST, + world::Union{Nothing,UInt}, + ::Val{justActive} = Val(false), + ::Val{UnionSret} = Val(false), + ::Val{AbstractIsMixed} = Val(false), +)::ActivityState where {ST,T,justActive,UnionSret,AbstractIsMixed} + if T === Any + if AbstractIsMixed + return MixedState + else + return DupState + end + end + + if T === Union{} + return AnyState + end + + if T <: Complex && !(T isa UnionAll) + return active_reg_inner( + ptreltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) + end + + if T <: BigFloat + return DupState + end + + if T <: AbstractFloat + return ActiveState + end + + if T <: EnzymeCore.RNumber + return active_reg_inner( + numbereltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) + end + + if T <: Ptr || + T <: Core.LLVMPtr || + T <: Base.RefValue || + T <: Array || T <: EnzymeCore.RArray 
+ is_arrayorvararg_ty(T) + if justActive + return AnyState + end + + if is_arrayorvararg_ty(T) && + active_reg_inner( + ptreltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) == AnyState + return AnyState + else + if AbstractIsMixed && is_vararg_tup(T) + return MixedState + else + return DupState + end + end + end + + if T <: Integer + return AnyState + end + + if isghostty(T) || Core.Compiler.isconstType(T) || T <: Type + return AnyState + end + + inactivety = if typeof(world) === Nothing + EnzymeCore.EnzymeRules.inactive_type(T) + else + inmi = my_methodinstance( + typeof(EnzymeCore.EnzymeRules.inactive_type), + Tuple{Type{T}}, + world, + ) + args = Any[EnzymeCore.EnzymeRules.inactive_type, T] + GC.@preserve T begin + ccall( + :jl_invoke, + Any, + (Any, Ptr{Any}, Cuint, Any), + EnzymeCore.EnzymeRules.inactive_type, + args, + length(args), + inmi, + ) + end + end + + if inactivety + return AnyState + end + + # unknown number of fields + if T isa UnionAll + aT = Base.argument_datatype(T) + if aT === nothing + if AbstractIsMixed + return MixedState + else + return DupState + end + end + if datatype_fieldcount(aT) === nothing + if AbstractIsMixed + return MixedState + else + return DupState + end + end + end + + if T isa Union + # if sret union, the data is stored in a stack memory location and is therefore + # not unique'd preventing the boxing of the union in the default case + if UnionSret && is_sret_union(T) + return active_reg_recur( + T, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) + else + if justActive + return AnyState + end + if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != + AnyState + if AbstractIsMixed + return MixedState + else + return DupState + end + end + if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != + AnyState + if AbstractIsMixed + return MixedState + else + return DupState + end + end + end + return AnyState + end + + 
# if abstract it must be by reference + if Base.isabstracttype(T) || T == Tuple + if AbstractIsMixed + return MixedState + else + return DupState + end + end + + if ismutabletype(T) + # if just looking for active of not + # we know for a fact this isn't active + if justActive + return AnyState + end + end + + @assert !Base.isabstracttype(T) + if !(Base.isconcretetype(T) || T <: Tuple || T isa UnionAll) + throw(AssertionError("Type $T is not concrete type or concrete tuple")) + end + + nT = if T <: Tuple && !(T isa UnionAll) + Tuple{( + ntuple(length(T.parameters)) do i + Base.@_inline_meta + sT = T.parameters[i] + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg + Any + else + sT + end + end + )...} + else + T + end + + if staticInTup(Val(nT), seen) + return MixedState + end + + seen2 = (Val(nT), seen...) + + fty = Merger{seen2,typeof(world),justActive,UnionSret,AbstractIsMixed}(world) + + ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...) + + return ty +end + +@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T,world} + return active_reg_inner(T, (), world) +end + +Base.@pure @inline function active_reg( + ::Type{T}, + world::Union{Nothing,UInt} = nothing, +)::Bool where {T} + seen = () + + # check if it could contain an active + if active_reg_inner(T, seen, world, Val(true)) == ActiveState #=justActive=# + state = active_reg_inner(T, seen, world, Val(false)) #=justActive=# + if state == ActiveState + return true + end + @assert state == MixedState + throw( + AssertionError( + string(T) * + " has mixed internal activity types. 
See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information", + ), + ) + else + return false + end +end + +@inline function guaranteed_const(::Type{T}) where {T} + rt = active_reg_nothrow(T, Val(nothing)) + res = rt == AnyState + return res +end + +@inline function guaranteed_const_nongen(::Type{T}, world) where {T} + rt = active_reg_inner(T, (), world) + res = rt == AnyState + return res +end + +# check if a value is guaranteed to be not contain active[register] data +# (aka not either mixed or active) +@inline function guaranteed_nonactive(::Type{T}) where {T} + rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + +""" + Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) + +Try to guess the most appropriate [`Annotation`](@ref) for arguments of type `T` passed to [`autodiff`](@ref) with a given `mode`. +""" +@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = + guess_activity(T, convert(API.CDerivativeMode, mode)) + +@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} + ActReg = active_reg_inner(T, (), nothing) + if ActReg == AnyState + return Const{T} + end + if Mode == API.DEM_ForwardMode + return Duplicated{T} + else + if ActReg == ActiveState + return Active{T} + elseif ActReg == MixedState + return MixedDuplicated{T} + else + return Duplicated{T} + end + end +end diff --git a/src/typeanalysis.jl b/src/analyses/type.jl similarity index 100% rename from src/typeanalysis.jl rename to src/analyses/type.jl diff --git a/src/compiler.jl b/src/compiler.jl index c94a6c84b4..16d0119e97 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -159,295 +159,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data return nothing, nothing, nothing end -const nofreefns = Set{String}(( - "ClientGetDevice", - "BufferOnCPU", - "pcre2_match_8", - "julia.gcroot_flush", - 
"pcre2_jit_stack_assign_8", - "pcre2_match_context_create_8", - "pcre2_jit_stack_create_8", - "ijl_gc_enable_finalizers_internal", - "jl_gc_enable_finalizers_internal", - "pcre2_match_data_create_from_pattern_8", - "ijl_gc_run_pending_finalizers", - "jl_gc_run_pending_finalizers", - "ijl_typeassert", - "jl_typeassert", - "ijl_f_isdefined", - "jl_f_isdefined", - "ijl_field_index", - "jl_field_index", - "ijl_specializations_get_linfo", - "jl_specializations_get_linfo", - "ijl_gf_invoke_lookup_worlds", - "jl_gf_invoke_lookup_worlds", - "ijl_gc_get_total_bytes", - "jl_gc_get_total_bytes", - "ijl_array_grow_at", - "jl_array_grow_at", - "ijl_try_substrtod", - "jl_try_substrtod", - "jl_f__apply_iterate", - "ijl_field_index", - "jl_field_index", - "julia.call", - "julia.call2", - "ijl_tagged_gensym", - "jl_tagged_gensym", - "ijl_array_ptr_copy", - "jl_array_ptr_copy", - "ijl_array_copy", - "jl_array_copy", - "ijl_genericmemory_copy_slice", - "jl_genericmemory_copy_slice", - "ijl_get_nth_field_checked", - "ijl_get_nth_field_checked", - "jl_array_del_end", - "ijl_array_del_end", - "jl_get_world_counter", - "ijl_get_world_counter", - "memhash32_seed", - "memhash_seed", - "ijl_module_parent", - "jl_module_parent", - "julia.safepoint", - "ijl_set_task_tid", - "jl_set_task_tid", - "ijl_get_task_tid", - "jl_get_task_tid", - "julia.get_pgcstack_or_new", - "ijl_global_event_loop", - "jl_global_event_loop", - "ijl_gf_invoke_lookup", - "jl_gf_invoke_lookup", - "ijl_f_typeassert", - "jl_f_typeassert", - "ijl_type_unionall", - "jl_type_unionall", - "jl_gc_queue_root", - "gpu_report_exception", - "gpu_signal_exception", - "julia.ptls_states", - "julia.write_barrier", - "julia.typeof", - "jl_backtrace_from_here", - "ijl_backtrace_from_here", - "jl_box_int64", - "jl_box_int32", - "ijl_box_int64", - "ijl_box_int32", - "jl_box_uint64", - "jl_box_uint32", - "ijl_box_uint64", - "ijl_box_uint32", - "ijl_box_char", - "jl_box_char", - "ijl_subtype", - "jl_subtype", - "julia.get_pgcstack", - 
"jl_in_threaded_region", - "jl_object_id_", - "jl_object_id", - "ijl_object_id_", - "ijl_object_id", - "jl_breakpoint", - "llvm.julia.gc_preserve_begin", - "llvm.julia.gc_preserve_end", - "jl_get_ptls_states", - "ijl_get_ptls_states", - "jl_f_fieldtype", - "jl_symbol_n", - "jl_stored_inline", - "ijl_stored_inline", - "jl_f_apply_type", - "jl_f_issubtype", - "jl_isa", - "ijl_isa", - "jl_matching_methods", - "ijl_matching_methods", - "jl_excstack_state", - "ijl_excstack_state", - "jl_current_exception", - "ijl_current_exception", - "memhash_seed", - "jl_f__typevar", - "ijl_f__typevar", - "jl_f_isa", - "ijl_f_isa", - "jl_set_task_threadpoolid", - "ijl_set_task_threadpoolid", - "jl_types_equal", - "ijl_types_equal", - "jl_invoke", - "ijl_invoke", - "jl_apply_generic", - "ijl_apply_generic", - "jl_egal__unboxed", - "julia.pointer_from_objref", - "_platform_memcmp", - "memcmp", - "julia.except_enter", - "jl_array_grow_end", - "ijl_array_grow_end", - "jl_f_getfield", - "ijl_f_getfield", - "jl_pop_handler", - "ijl_pop_handler", - "jl_pop_handler_noexcept", - "ijl_pop_handler_noexcept", - "jl_string_to_array", - "ijl_string_to_array", - "jl_alloc_string", - "ijl_alloc_string", - "getenv", - "jl_cstr_to_string", - "ijl_cstr_to_string", - "jl_symbol_n", - "ijl_symbol_n", - "uv_os_homedir", - "jl_array_to_string", - "ijl_array_to_string", - "pcre2_jit_compile_8", - "memmove", -)) - -const inactivefns = Set{String}(( - "ClientGetDevice", - "BufferOnCPU", - "pcre2_match_data_create_from_pattern_8", - "ijl_typeassert", - "jl_typeassert", - "ijl_f_isdefined", - "jl_f_isdefined", - "ijl_field_index", - "jl_field_index", - "ijl_specializations_get_linfo", - "jl_specializations_get_linfo", - "ijl_gf_invoke_lookup_worlds", - "jl_gf_invoke_lookup_worlds", - "ijl_gc_get_total_bytes", - "jl_gc_get_total_bytes", - "ijl_try_substrtod", - "jl_try_substrtod", - "ijl_tagged_gensym", - "jl_tagged_gensym", - "jl_get_world_counter", - "ijl_get_world_counter", - "memhash32_seed", - 
"memhash_seed", - "ijl_module_parent", - "jl_module_parent", - "julia.safepoint", - "ijl_set_task_tid", - "jl_set_task_tid", - "ijl_get_task_tid", - "jl_get_task_tid", - "julia.get_pgcstack_or_new", - "ijl_global_event_loop", - "jl_global_event_loop", - "ijl_gf_invoke_lookup", - "jl_gf_invoke_lookup", - "ijl_f_typeassert", - "jl_f_typeassert", - "ijl_type_unionall", - "jl_type_unionall", - "jl_gc_queue_root", - "gpu_report_exception", - "gpu_signal_exception", - "julia.ptls_states", - "julia.write_barrier", - "julia.typeof", - "jl_backtrace_from_here", - "ijl_backtrace_from_here", - "jl_box_int64", - "jl_box_int32", - "ijl_box_int64", - "ijl_box_int32", - "jl_box_uint64", - "jl_box_uint32", - "ijl_box_uint64", - "ijl_box_uint32", - "ijl_box_char", - "jl_box_char", - "ijl_subtype", - "jl_subtype", - "julia.get_pgcstack", - "jl_in_threaded_region", - "jl_object_id_", - "jl_object_id", - "ijl_object_id_", - "ijl_object_id", - "jl_breakpoint", - "llvm.julia.gc_preserve_begin", - "llvm.julia.gc_preserve_end", - "jl_get_ptls_states", - "ijl_get_ptls_states", - "jl_f_fieldtype", - "jl_symbol_n", - "jl_stored_inline", - "ijl_stored_inline", - "jl_f_apply_type", - "jl_f_issubtype", - "jl_isa", - "ijl_isa", - "jl_matching_methods", - "ijl_matching_methods", - "jl_excstack_state", - "ijl_excstack_state", - "jl_current_exception", - "ijl_current_exception", - "memhash_seed", - "jl_f__typevar", - "ijl_f__typevar", - "jl_f_isa", - "ijl_f_isa", - "jl_set_task_threadpoolid", - "ijl_set_task_threadpoolid", - "jl_types_equal", - "ijl_types_equal", - "jl_string_to_array", - "ijl_string_to_array", - "jl_alloc_string", - "ijl_alloc_string", - "getenv", - "jl_cstr_to_string", - "ijl_cstr_to_string", - "jl_symbol_n", - "ijl_symbol_n", - "uv_os_homedir", - "jl_array_to_string", - "ijl_array_to_string", - "pcre2_jit_compile_8", - # "jl_" -)) - -const activefns = Set{String}(("jl_",)) - -const inactiveglobs = Set{String}(( - "ijl_boxed_uint8_cache", - "jl_boxed_uint8_cache", - 
"ijl_boxed_int8_cache", - "jl_boxed_int8_cache", - "jl_nothing", -)) - -@enum ActivityState begin - AnyState = 0 - ActiveState = 1 - DupState = 2 - MixedState = 3 -end - -@inline function Base.:|(a1::ActivityState, a2::ActivityState) - ActivityState(Int(a1) | Int(a2)) -end - -struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed} - world::worldT -end - -@inline element(::Val{T}) where {T} = T +include("llvm/attributes.jl") # From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 @inline function isghostty(@nospecialize(ty)) @@ -463,446 +175,7 @@ end return false end -@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})( - f::Int, -) where {seen,worldT,justActive,UnionSret,AbstractIsMixed} - T = element(first(seen)) - - reftype = ismutabletype(T) || (T isa UnionAll && !AbstractIsMixed) - - if justActive && reftype - return Val(AnyState) - end - - subT = typed_fieldtype(T, f) - - if justActive && ismutabletype(subT) - return Val(AnyState) - end - - sub = active_reg_inner( - subT, - seen, - c.world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ) - - if sub == AnyState - Val(AnyState) - else - if sub == DupState - if justActive - Val(AnyState) - else - Val(DupState) - end - else - if reftype - Val(DupState) - else - Val(sub) - end - end - end -end - -@inline forcefold(::Val{RT}) where {RT} = RT - -@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N}) where {ty,sty,N} - if sty == AnyState || sty == ty - return forcefold(Val(ty), C...) - end - if ty == AnyState - return forcefold(Val(sty), C...) 
- else - return MixedState - end -end - -@inline numbereltype(::Type{<:EnzymeCore.RNumber{T}}) where {T} = T -@inline ptreltype(::Type{<:EnzymeCore.RArray{T}}) where {T} = T -@inline ptreltype(::Type{Ptr{T}}) where {T} = T -@inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T -@inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T -@inline ptreltype(::Type{Base.RefValue{T}}) where {T} = T -@inline ptreltype(::Type{Array{T,N}}) where {T,N} = T -@inline ptreltype(::Type{Array{T,N} where N}) where {T} = T -@inline ptreltype(::Type{Complex{T}}) where {T} = T -@inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T -@inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V -@inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V -@inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T -@static if VERSION < v"1.11-" -else -@inline ptreltype(::Type{Memory{T}}) where T = T -end - -@inline is_arrayorvararg_ty(::Type) = false -@inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true -@inline is_arrayorvararg_ty(::Type{Array{T,N} where N}) where {T} = true -@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where {T2} = true -@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where {T} = true -@inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N}}) where {T,N} = true -@inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N} where N}) where {T} = true -@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true -@inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true -@inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true -@inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true -@static if VERSION < v"1.11-" -else -@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true -end - -@inline function datatype_fieldcount(t::Type{T}) where {T} - return Base.datatype_fieldcount(t) -end - -@inline function staticInTup(::Val{T}, tup::NTuple{N,Val}) where {T,N} - 
any(ntuple(Val(N)) do i - Base.@_inline_meta - Val(T) == tup[i] - end) -end - -@inline function active_reg_recur( - ::Type{ST}, - seen::Seen, - world, - ::Val{justActive}, - ::Val{UnionSret}, - ::Val{AbstractIsMixed}, -) where {ST,Seen,justActive,UnionSret,AbstractIsMixed} - if ST isa Union - return forcefold( - Val( - active_reg_recur( - ST.a, - seen, - world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ), - ), - Val( - active_reg_recur( - ST.b, - seen, - world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ), - ), - ) - end - return active_reg_inner( - ST, - seen, - world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ) -end - -@inline is_vararg_tup(x) = false -@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where {T2} = true - -@inline function active_reg_inner( - ::Type{T}, - seen::ST, - world::Union{Nothing,UInt}, - ::Val{justActive} = Val(false), - ::Val{UnionSret} = Val(false), - ::Val{AbstractIsMixed} = Val(false), -)::ActivityState where {ST,T,justActive,UnionSret,AbstractIsMixed} - if T === Any - if AbstractIsMixed - return MixedState - else - return DupState - end - end - - if T === Union{} - return AnyState - end - - if T <: Complex && !(T isa UnionAll) - return active_reg_inner( - ptreltype(T), - seen, - world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ) - end - - if T <: BigFloat - return DupState - end - - if T <: AbstractFloat - return ActiveState - end - - if T <: EnzymeCore.RNumber - return active_reg_inner( - numbereltype(T), - seen, - world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ) - end - - if T <: Ptr || - T <: Core.LLVMPtr || - T <: Base.RefValue || - T <: Array || T <: EnzymeCore.RArray - is_arrayorvararg_ty(T) - if justActive - return AnyState - end - - if is_arrayorvararg_ty(T) && - active_reg_inner( - ptreltype(T), - seen, - world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ) == AnyState - return AnyState - else - if 
AbstractIsMixed && is_vararg_tup(T) - return MixedState - else - return DupState - end - end - end - - if T <: Integer - return AnyState - end - - if isghostty(T) || Core.Compiler.isconstType(T) || T <: Type - return AnyState - end - - inactivety = if typeof(world) === Nothing - EnzymeCore.EnzymeRules.inactive_type(T) - else - inmi = my_methodinstance( - typeof(EnzymeCore.EnzymeRules.inactive_type), - Tuple{Type{T}}, - world, - ) - args = Any[EnzymeCore.EnzymeRules.inactive_type, T] - GC.@preserve T begin - ccall( - :jl_invoke, - Any, - (Any, Ptr{Any}, Cuint, Any), - EnzymeCore.EnzymeRules.inactive_type, - args, - length(args), - inmi, - ) - end - end - - if inactivety - return AnyState - end - - # unknown number of fields - if T isa UnionAll - aT = Base.argument_datatype(T) - if aT === nothing - if AbstractIsMixed - return MixedState - else - return DupState - end - end - if datatype_fieldcount(aT) === nothing - if AbstractIsMixed - return MixedState - else - return DupState - end - end - end - - if T isa Union - # if sret union, the data is stored in a stack memory location and is therefore - # not unique'd preventing the boxing of the union in the default case - if UnionSret && is_sret_union(T) - return active_reg_recur( - T, - seen, - world, - Val(justActive), - Val(UnionSret), - Val(AbstractIsMixed), - ) - else - if justActive - return AnyState - end - if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != - AnyState - if AbstractIsMixed - return MixedState - else - return DupState - end - end - if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != - AnyState - if AbstractIsMixed - return MixedState - else - return DupState - end - end - end - return AnyState - end - - # if abstract it must be by reference - if Base.isabstracttype(T) || T == Tuple - if AbstractIsMixed - return MixedState - else - return DupState - end - end - - if ismutabletype(T) - # if just looking for active of not - # we know for a fact this isn't 
active - if justActive - return AnyState - end - end - - @assert !Base.isabstracttype(T) - if !(Base.isconcretetype(T) || T <: Tuple || T isa UnionAll) - throw(AssertionError("Type $T is not concrete type or concrete tuple")) - end - - nT = if T <: Tuple && !(T isa UnionAll) - Tuple{( - ntuple(length(T.parameters)) do i - Base.@_inline_meta - sT = T.parameters[i] - if sT isa TypeVar - Any - elseif sT isa Core.TypeofVararg - Any - else - sT - end - end - )...} - else - T - end - - if staticInTup(Val(nT), seen) - return MixedState - end - - seen2 = (Val(nT), seen...) - - fty = Merger{seen2,typeof(world),justActive,UnionSret,AbstractIsMixed}(world) - - ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...) - - return ty -end - -@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T,world} - return active_reg_inner(T, (), world) -end - -Base.@pure @inline function active_reg( - ::Type{T}, - world::Union{Nothing,UInt} = nothing, -)::Bool where {T} - seen = () - - # check if it could contain an active - if active_reg_inner(T, seen, world, Val(true)) == ActiveState #=justActive=# - state = active_reg_inner(T, seen, world, Val(false)) #=justActive=# - if state == ActiveState - return true - end - @assert state == MixedState - throw( - AssertionError( - string(T) * - " has mixed internal activity types. 
See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information", - ), - ) - else - return false - end -end - -@inline function guaranteed_const(::Type{T}) where {T} - rt = active_reg_nothrow(T, Val(nothing)) - res = rt == AnyState - return res -end - -@inline function guaranteed_const_nongen(::Type{T}, world) where {T} - rt = active_reg_inner(T, (), world) - res = rt == AnyState - return res -end - -# check if a value is guaranteed to be not contain active[register] data -# (aka not either mixed or active) -@inline function guaranteed_nonactive(::Type{T}) where {T} - rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) - return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState -end - -""" - Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) - -Try to guess the most appropriate [`Annotation`](@ref) for arguments of type `T` passed to [`autodiff`](@ref) with a given `mode`. -""" -@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = - guess_activity(T, convert(API.CDerivativeMode, mode)) - -@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} - ActReg = active_reg_inner(T, (), nothing) - if ActReg == AnyState - return Const{T} - end - if Mode == API.DEM_ForwardMode - return Duplicated{T} - else - if ActReg == ActiveState - return Active{T} - elseif ActReg == MixedState - return MixedDuplicated{T} - else - return Duplicated{T} - end - end -end +include("analyses/activity.jl") # User facing interface abstract type AbstractThunk{FA,RT,TT,Width} end @@ -1205,7 +478,7 @@ struct Tape{TapeTy,ShadowTy,ResT} shadow_return::ShadowTy end -include("make_zero.jl") +include("typeutils/make_zero.jl") function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) funcspec = my_methodinstance(typeof(f), tt, world) @@ -1315,709 +588,71 @@ function removed_ret_parms(F::LLVM.Function) parmrem = a end if kind(a) == "enzyme_retremove" - 
retRemove = true - end - end - end - if parmrem !== nothing - str = value(parmrem) - for v in eachsplit(str, ",") - push!(parmsRemoved, parse(UInt64, v)) - end - end - return retRemove, parmsRemoved -end - -abstract type CompilationException <: Base.Exception end -struct NoDerivativeException <: CompilationException - msg::String - ir::Union{Nothing,String} - bt::Union{Nothing,Vector{StackTraces.StackFrame}} -end - -function Base.showerror(io::IO, ece::NoDerivativeException) - print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing - print(io, "Current scope: \n") - print(io, ece.ir) - end - print(io, '\n', ece.msg, '\n') - if ece.bt !== nothing - Base.show_backtrace(io, ece.bt) - println(io) - end -end - -struct IllegalTypeAnalysisException <: CompilationException - msg::String - sval::String - ir::Union{Nothing,String} - bt::Union{Nothing,Vector{StackTraces.StackFrame}} -end - -function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) - print(io, "Enzyme compilation failed due to illegal type analysis.\n") - if ece.ir !== nothing - print(io, "Current scope: \n") - print(io, ece.ir) - end - print(io, "\n Type analysis state: \n") - write(io, ece.sval) - print(io, '\n', ece.msg, '\n') - if ece.bt !== nothing - print(io, "\nCaused by:") - Base.show_backtrace(io, ece.bt) - println(io) - end -end - -struct IllegalFirstPointerException <: CompilationException - msg::String - ir::Union{Nothing,String} - bt::Union{Nothing,Vector{StackTraces.StackFrame}} -end - -function Base.showerror(io::IO, ece::IllegalFirstPointerException) - print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing - print(io, "Current scope: \n") - print(io, ece.ir) - end - print(io, '\n', ece.msg, '\n') - if ece.bt !== nothing - Base.show_backtrace(io, ece.bt) - println(io) - end -end - -struct EnzymeInternalError <: CompilationException - msg::String - ir::Union{Nothing,String} - bt::Union{Nothing,Vector{StackTraces.StackFrame}} -end - -function Base.showerror(io::IO, 
ece::EnzymeInternalError) - print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing - print(io, "Current scope: \n") - print(io, ece.ir) - end - print(io, '\n', ece.msg, '\n') - if ece.bt !== nothing - Base.show_backtrace(io, ece.bt) - println(io) - end -end - -parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? LLVM.parent(val) : val -parent_scope(val::LLVM.Module, depth = 0) = val -parent_scope(@nospecialize(val::LLVM.Value), depth = 0) = parent_scope(LLVM.parent(val), depth + 1) -parent_scope(val::LLVM.Argument, depth = 0) = - parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth + 1) - -const CheckNan = Ref(false) -function julia_sanitize( - orig::LLVM.API.LLVMValueRef, - val::LLVM.API.LLVMValueRef, - B::LLVM.API.LLVMBuilderRef, - mask::LLVM.API.LLVMValueRef, -)::LLVM.API.LLVMValueRef - orig = LLVM.Value(orig) - val = LLVM.Value(val) - B = LLVM.IRBuilder(B) - if CheckNan[] - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - ty = LLVM.value_type(val) - vt = LLVM.VoidType() - FT = LLVM.FunctionType(vt, [ty, LLVM.PointerType(LLVM.Int8Type())]) - - stringv = "Enzyme: Found nan while computing derivative of " * string(orig) - if orig !== nothing && isa(orig, LLVM.Instruction) - bt = GPUCompiler.backtrace(orig) - stringv *= sprint(Base.Fix2(Base.show_backtrace, bt)) - end - - fn, _ = get_function!(mod, "julia.sanitize." 
* string(ty), FT) - if isempty(blocks(fn)) - let builder = IRBuilder() - entry = BasicBlock(fn, "entry") - good = BasicBlock(fn, "good") - bad = BasicBlock(fn, "bad") - position!(builder, entry) - inp, sval = collect(parameters(fn)) - cmp = fcmp!(builder, LLVM.API.LLVMRealUNO, inp, inp) - - br!(builder, cmp, bad, good) - - position!(builder, good) - ret!(builder) - - position!(builder, bad) - - emit_error(builder, nothing, sval, EnzymeNoDerivativeError) - unreachable!(builder) - dispose(builder) - end - end - # val = - call!(B, FT, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) - end - return val.ref -end - -function julia_error( - cstr::Cstring, - val::LLVM.API.LLVMValueRef, - errtype::API.ErrorType, - data::Ptr{Cvoid}, - data2::LLVM.API.LLVMValueRef, - B::LLVM.API.LLVMBuilderRef, -)::LLVM.API.LLVMValueRef - msg = Base.unsafe_string(cstr) - bt = nothing - ir = nothing - if val != C_NULL - val = LLVM.Value(val) - if isa(val, LLVM.Instruction) - dbgval = val - while !haskey(metadata(dbgval), LLVM.MD_dbg) - dbgval = LLVM.API.LLVMGetNextInstruction(dbgval) - if dbgval == C_NULL - dbgval = nothing - break - else - dbgval = LLVM.Instruction(dbgval) - end - end - if dbgval !== nothing - bt = GPUCompiler.backtrace(dbgval) - end - end - if isa(val, LLVM.ConstantExpr) - for u in LLVM.uses(val) - u = LLVM.user(u) - if isa(u, LLVM.Instruction) - bt = GPUCompiler.backtrace(val) - end - end - else - # Need to convert function to string, since when the error is going to be printed - # the module might have been destroyed - ir = string(parent_scope(val)) - end - end - - if errtype == API.ET_NoDerivative - if occursin("No create nofree of empty function", msg) || - occursin("No forward mode derivative found for", msg) || - occursin("No augmented forward pass", msg) || - occursin("No reverse pass found", msg) - ir = nothing - end - if B != C_NULL - B = IRBuilder(B) - msg2 = sprint() do io - if ir !== nothing - print(io, "Current scope: \n") - print(io, ir) - end - print(io, 
'\n', msg, '\n') - if bt !== nothing - Base.show_backtrace(io, bt) - println(io) - end - end - emit_error(B, nothing, msg2, EnzymeNoDerivativeError) - return C_NULL - end - throw(NoDerivativeException(msg, ir, bt)) - elseif errtype == API.ET_NoShadow - gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) - - msgN = sprint() do io::IO - if isa(val, LLVM.Argument) - fn = parent_scope(val) - ir = string(LLVM.name(fn)) * string(function_type(fn)) - print(io, "Current scope: \n") - print(io, ir) - end - if !isa(val, LLVM.Argument) - print(io, "\n Inverted pointers: \n") - ip = API.EnzymeGradientUtilsInvertedPointersToString(gutils) - sval = Base.unsafe_string(ip) - write(io, sval) - API.EnzymeStringFree(ip) - end - print(io, '\n', msg, '\n') - if bt !== nothing - print(io, "\nCaused by:") - Base.show_backtrace(io, bt) - println(io) - end - end - emit_error(IRBuilder(B), nothing, msgN, EnzymeNoShadowError) - return LLVM.null(get_shadow_type(gutils, value_type(val))).ref - elseif errtype == API.ET_IllegalTypeAnalysis - data = API.EnzymeTypeAnalyzerRef(data) - ip = API.EnzymeTypeAnalyzerToString(data) - sval = Base.unsafe_string(ip) - API.EnzymeStringFree(ip) - - if isa(val, LLVM.Instruction) - mi, rt = enzyme_custom_extract_mi( - LLVM.parent(LLVM.parent(val))::LLVM.Function, - false, - ) #=error=# - if mi !== nothing - msg *= "\n" * string(mi) * "\n" - end - end - throw(IllegalTypeAnalysisException(msg, sval, ir, bt)) - elseif errtype == API.ET_NoType - @assert B != C_NULL - B = IRBuilder(B) - - data = API.EnzymeTypeAnalyzerRef(data) - ip = API.EnzymeTypeAnalyzerToString(data) - sval = Base.unsafe_string(ip) - API.EnzymeStringFree(ip) - - msg2 = sprint() do io::IO - if !occursin("Cannot deduce single type of store", msg) - if ir !== nothing - print(io, "Current scope: \n") - print(io, ir) - end - print(io, "\n Type analysis state: \n") - write(io, sval) - end - print(io, '\n', msg, '\n') - if bt !== nothing - print(io, "\nCaused by:") - Base.show_backtrace(io, bt) - 
println(io) - end - pscope = parent_scope(val) - mi, rt = enzyme_custom_extract_mi(pscope, false) #=error=# - if mi !== nothing - println(io, "within ", mi) - end - end - emit_error(B, nothing, msg2, EnzymeNoTypeError) - return C_NULL - elseif errtype == API.ET_IllegalFirstPointer - throw(IllegalFirstPointerException(msg, ir, bt)) - elseif errtype == API.ET_InternalError - throw(EnzymeInternalError(msg, ir, bt)) - elseif errtype == API.ET_TypeDepthExceeded - msg2 = sprint() do io - print(io, msg) - println(io) - - if val != C_NULL - println(io, val) - end - - st = API.EnzymeTypeTreeToString(data) - println(io, Base.unsafe_string(st)) - API.EnzymeStringFree(st) - - if bt !== nothing - Base.show_backtrace(io, bt) - end - end - GPUCompiler.@safe_warn msg2 - return C_NULL - elseif errtype == API.ET_IllegalReplaceFicticiousPHIs - data2 = LLVM.Value(data2) - msg2 = sprint() do io - print(io, msg) - println(io) - println(io, string(LLVM.parent(LLVM.parent(data2)))) - println(io, val) - println(io, data2) - end - throw(EnzymeInternalError(msg2, ir, bt)) - elseif errtype == API.ET_MixedActivityError - data2 = LLVM.Value(data2) - badval = nothing - gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) - # Ignore mismatched activity if phi/store of ghost - seen = Dict{LLVM.Value,LLVM.Value}() - illegal = false - created = LLVM.Instruction[] - world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) - width = get_width(gutils) - function make_batched(@nospecialize(cur::LLVM.Value), B::LLVM.IRBuilder)::LLVM.Value - if width == 1 - return cur - else - shadowres = UndefValue( - LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), - ) - for idx = 1:width - shadowres = insert_value!(B, shadowres, cur, idx - 1) - if isa(shadowres, LLVM.Instruction) - push!(created, shadowres) - end - end - return shadowres - end - end - - illegalVal = nothing - - function make_replacement(@nospecialize(cur::LLVM.Value), prevbb::LLVM.IRBuilder)::LLVM.Value - ncur = 
new_from_original(gutils, cur) - if cur in keys(seen) - return seen[cur] - end - - legal, TT, byref = abs_typeof(cur, true) - - if legal - if guaranteed_const_nongen(TT, world) - return make_batched(ncur, prevbb) - end - - legal2, obj = absint(cur) - - # Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple - if legal2 - if active_reg_inner(TT, (), world) == ActiveState && - isa(cur, LLVM.ConstantExpr) && - cur == data2 - if width == 1 - res = emit_allocobj!(prevbb, Base.RefValue{TT}) - push!(created, res) - return res - else - shadowres = UndefValue( - LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), - ) - for idx = 1:width - res = emit_allocobj!(prevbb, Base.RefValue{TT}) - shadowres = insert_value!(prevbb, shadowres, res, idx - 1) - push!(created, shadowres) - end - return shadowres - end - end - -@static if VERSION < v"1.11-" -else - if obj isa Memory && obj == typeof(obj).instance - return make_batched(ncur, prevbb) - end -end - end - -@static if VERSION < v"1.11-" -else - if isa(cur, LLVM.LoadInst) - larg, off = get_base_and_offset(operands(cur)[1]) - if isa(larg, LLVM.LoadInst) - legal2, obj = absint(larg) - if legal2 && obj isa Memory && obj == typeof(obj).instance - return make_batched(ncur, prevbb) - end - end - end -end - - badval = if legal2 - string(obj) * " of type" * " " * string(TT) - else - "Unknown object of type" * " " * string(TT) - end - @assert !illegal - illegalVal = cur - illegal = true - return make_batched(ncur, prevbb) - end - - if isa(cur, LLVM.PointerNull) - return make_batched(ncur, prevbb) - end - if isa(cur, LLVM.UndefValue) - return make_batched(ncur, prevbb) - end - if isa(cur, LLVM.PoisonValue) - return make_batched(ncur, prevbb) - end - if isa(cur, LLVM.ConstantAggregateZero) - return make_batched(ncur, prevbb) - end - if isa(cur, LLVM.ConstantAggregate) - return make_batched(ncur, prevbb) - end - if isa(cur, LLVM.ConstantInt) - if convert(UInt64, cur) == 0 - return 
make_batched(ncur, prevbb) - end - end - if isa(cur, LLVM.ConstantFP) - return make_batched(ConstantFP(value_type(cur), 0), prevbb) - end - if isa(cur, LLVM.ConstantDataSequential) - cvals = LLVM.Value[] - changed = false - for v in collect(cur) - tmp = make_replacement(v, prevbb) - if illegal - return ncur - end - if v != tmp - changed = true - end - push!(cvals, tmp) - end - - cur2 = if changed - @assert !illegal - illegalVal = cur - illegal = true - # TODO replace with correct insertions/splats - ncur - else - make_batched(ncur, prevbb) - end - return cur2 - end - if isa(cur, LLVM.ConstantInt) - if LLVM.width(value_type(cur)) <= sizeof(Int) * 8 - return make_batched(ncur, prevbb) - end - if LLVM.width(value_type(cur)) == sizeof(Int) * 8 && - abs(convert(Int, cur)) < 10000 - return make_batched(ncur, prevbb) - end - # if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe - # for activity state to mix - if isa(val, LLVM.StoreInst) - operands(val)[1] == cur && - !isa(value_type(operands(val)[1]), LLVM.PointerType) - return make_batched(ncur, prevbb) - end - end - - if isa(cur, LLVM.SelectInst) - lhs = make_replacement(operands(cur)[2], prevbb) - if illegal - return ncur - end - rhs = make_replacement(operands(cur)[3], prevbb) - if illegal - return ncur - end - if lhs == operands(cur)[2] && rhs == operands(cur)[3] - return make_batched(ncur, prevbb) - end - if width == 1 - nv = select!( - prevbb, - new_from_original(gutils, operands(cur)[1]), - lhs, - rhs, - ) - push!(created, nv) - seen[cur] = nv - return nv - else - shadowres = LLVM.UndefValue(value_type(lhs)) - for idx = 1:width - shadowres = insert_value!( - prevbb, - shadowres, - select!( - prevbb, - new_from_original(gutils, operands(cur)[1]), - extract_value!(prevbb, lhs, idx - 1), - extract_value!(prevbb, rhs, idx - 1), - ), - idx - 1, - ) - if isa(shadowres, LLVM.Instruction) - push!(created, shadowres) - end - end - return shadowres - end - end - - if isa(cur, 
LLVM.InsertValueInst) - lhs = make_replacement(operands(cur)[1], prevbb) - if illegal - return ncur - end - rhs = make_replacement(operands(cur)[2], prevbb) - if illegal - return ncur - end - if lhs == operands(cur)[1] && rhs == operands(cur)[2] - return make_batched(ncur, prevbb) - end - inds = LLVM.API.LLVMGetIndices(cur.ref) - ninds = LLVM.API.LLVMGetNumIndices(cur.ref) - jinds = Cuint[unsafe_load(inds, i) for i = 1:ninds] - if width == 1 - nv = API.EnzymeInsertValue(prevbb, lhs, rhs, jinds) - push!(created, nv) - seen[cur] = nv - return nv - else - shadowres = lhs - for idx = 1:width - jindsv = copy(jinds) - pushfirst!(jindsv, idx - 1) - shadowres = API.EnzymeInsertValue( - prevbb, - shadowres, - extract_value!(prevbb, rhs, idx - 1), - jindsv, - ) - if isa(shadowres, LLVM.Instruction) - push!(created, shadowres) - end - end - return shadowres - end - end - - if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(Base.Fix2(isa, LLVM.ConstantInt), operands(cur)[2:end])) - lhs = make_replacement(operands(cur)[1], prevbb) - if illegal - return ncur - end - if lhs == operands(ncur)[1] - return make_batched(ncur, prevbb) - elseif width != 1 && isa(lhs, LLVM.InsertValueInst) && operands(lhs)[2] == operands(ncur)[1] - return make_batched(ncur, prevbb) - end - end - - if isa(cur, LLVM.PHIInst) - Bphi = IRBuilder() - position!(Bphi, ncur) - shadowty = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))) - phi2 = phi!(Bphi, shadowty, "tempphi" * LLVM.name(cur)) - seen[cur] = phi2 - changed = false - recsize = length(created) + 1 - for (v, bb) in LLVM.incoming(cur) - B2 = IRBuilder() - position!(B2, new_from_original(gutils, last(instructions(bb)))) - tmp = make_replacement(v, B2) - if illegal - changed = true - break - end - @assert value_type(tmp) == shadowty - if tmp != new_from_original(gutils, v) && v != cur - changed = true - end - push!(LLVM.incoming(phi2), (tmp, 
new_from_original(gutils, bb))) - end - if !changed || illegal - LLVM.API.LLVMInstructionEraseFromParent(phi2) - seen[cur] = ncur - plen = length(created) - for i = recsize:plen - u = created[i] - replace_uses!(u, LLVM.UndefValue(value_type(u))) - end - for i = recsize:plen - u = created[i] - LLVM.API.LLVMInstructionEraseFromParent(u) - end - for i = recsize:plen - pop!(created) - end - return illegal ? ncur : make_batched(ncur, prevbb) - end - push!(created, phi2) - return phi2 - end - - tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, cur)) - st = API.EnzymeTypeTreeToString(tt) - st2 = Base.unsafe_string(st) - API.EnzymeStringFree(st) - if st2 == "{[-1]:Integer}" - return make_batched(ncur, prevbb) - end - - if !illegal - illegal = true - illegalVal = cur + retRemove = true end - return ncur end + end + if parmrem !== nothing + str = value(parmrem) + for v in eachsplit(str, ",") + push!(parmsRemoved, parse(UInt64, v)) + end + end + return retRemove, parmsRemoved +end - b = IRBuilder(B) - replacement = make_replacement(data2, b) +include("errors.jl") - if !illegal - return replacement.ref - end - for u in created - replace_uses!(u, LLVM.UndefValue(value_type(u))) - end - for u in created - LLVM.API.LLVMInstructionEraseFromParent(u) - end - if LLVM.API.LLVMIsAReturnInst(val) != C_NULL - mi, rt = enzyme_custom_extract_mi( - LLVM.parent(LLVM.parent(val))::LLVM.Function, - false, - ) #=error=# - if mi !== nothing && isghostty(rt) - return C_NULL - end - end - msg2 = sprint() do io - print(io, msg) - println(io) - if badval !== nothing - println(io, " value=" * badval) - else - ttval = val - if isa(ttval, LLVM.StoreInst) - ttval = operands(ttval)[1] - end - tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, ttval)) - st = API.EnzymeTypeTreeToString(tt) - print(io, "Type tree: ") - println(io, Base.unsafe_string(st)) - API.EnzymeStringFree(st) - end - if illegalVal !== nothing - println(io, " llvalue=" * string(illegalVal)) - end - if bt 
!== nothing - Base.show_backtrace(io, bt) - end +const CheckNan = Ref(false) +function julia_sanitize( + orig::LLVM.API.LLVMValueRef, + val::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, + mask::LLVM.API.LLVMValueRef, +)::LLVM.API.LLVMValueRef + orig = LLVM.Value(orig) + val = LLVM.Value(val) + B = LLVM.IRBuilder(B) + if CheckNan[] + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + ty = LLVM.value_type(val) + vt = LLVM.VoidType() + FT = LLVM.FunctionType(vt, [ty, LLVM.PointerType(LLVM.Int8Type())]) + + stringv = "Enzyme: Found nan while computing derivative of " * string(orig) + if orig !== nothing && isa(orig, LLVM.Instruction) + bt = GPUCompiler.backtrace(orig) + stringv *= sprint(Base.Fix2(Base.show_backtrace, bt)) end - emit_error(b, nothing, msg2, EnzymeRuntimeActivityError) - return C_NULL - elseif errtype == API.ET_GetIndexError - @assert B != C_NULL - B = IRBuilder(B) - msg5 = sprint() do io::IO - print(io, "Enzyme internal error\n") - print(io, msg, '\n') - if bt !== nothing - print(io, "\nCaused by:") - Base.show_backtrace(io, bt) - println(io) + + fn, _ = get_function!(mod, "julia.sanitize." 
* string(ty), FT) + if isempty(blocks(fn)) + let builder = IRBuilder() + entry = BasicBlock(fn, "entry") + good = BasicBlock(fn, "good") + bad = BasicBlock(fn, "bad") + position!(builder, entry) + inp, sval = collect(parameters(fn)) + cmp = fcmp!(builder, LLVM.API.LLVMRealUNO, inp, inp) + + br!(builder, cmp, bad, good) + + position!(builder, good) + ret!(builder) + + position!(builder, bad) + + emit_error(builder, nothing, sval, EnzymeNoDerivativeError) + unreachable!(builder) + dispose(builder) end end - emit_error(B, nothing, msg5) - return C_NULL + # val = + call!(B, FT, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) end - throw(AssertionError("Unknown errtype")) + return val.ref end function any_jltypes(Type::LLVM.PointerType) @@ -3227,664 +1862,59 @@ function primal_return_type_generator(world::UInt, source, self, @nospecialize(m m, mtypes, msp, - ) - ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo - - # prepare a new code info - new_ci = copy(ci) - empty!(new_ci.code) - @static if isdefined(Core, :DebugInfo) - new_ci.debuginfo = Core.DebugInfo(:none) - else - empty!(new_ci.codelocs) - resize!(new_ci.linetable, 1) # see note below - end - empty!(new_ci.ssaflags) - new_ci.ssavaluetypes = 0 - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] - new_ci.edges = Core.MethodInstance[mi] - # XXX: setting this edge does not give us proper method invalidation, see - # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. - # invoking `code_llvm` also does the necessary codegen, as does calling the - # underlying C methods -- which GPUCompiler does, so everything Just Works. 
- - # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :mode, :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:4] - - # return the codegen world age - res = primal_return_type_world(mode, world, mi) - push!(new_ci.code, Core.Compiler.ReturnNode(res)) - push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` - @static if isdefined(Core, :DebugInfo) - else - push!(new_ci.codelocs, 1) # see note below - end - new_ci.ssavaluetypes += 1 - - # NOTE: we keep the first entry of the original linetable, and use it for location info - # on the call to check_cache. we can't not have a codeloc (using 0 causes - # corruption of the back trace), and reusing the target function's info - # has as advantage that we see the name of the kernel in the backtraces. - - return new_ci -end - -@eval Base.@assume_effects :removable :foldable :nothrow @inline function primal_return_type(mode::Mode, ft::Type, tt::Type) - $(Expr(:meta, :generated_only)) - $(Expr(:meta, :generated, primal_return_type_generator)) -end - -## -# Enzyme compiler step -## - -function annotate!(mod::LLVM.Module) - inactive = LLVM.StringAttribute("enzyme_inactive", "") - active = LLVM.StringAttribute("enzyme_active", "") - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - - funcs = Dict{String, Vector{LLVM.Function}}() - for f in functions(mod) - fname = LLVM.name(f) - for fattr in collect(function_attributes(f)) - if isa(fattr, LLVM.StringAttribute) - if kind(fattr) == "enzyme_math" - fname = LLVM.value(fattr) - break - end - end - end - fname = String(fname) - if !haskey(funcs, fname) - funcs[fname] = LLVM.Function[] - end - push!(funcs[String(fname)], f) - API.EnzymeAttributeKnownFunctions(f.ref) - end - - for gname in inactiveglobs - globs = LLVM.globals(mod) - if haskey(globs, gname) - glob = globs[gname] - API.SetMD(glob, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) - end - end - - for 
fname in inactivefns - if haskey(funcs, fname) - for fn in funcs[fname] - push!(function_attributes(fn), inactive) - push!(function_attributes(fn), no_escaping_alloc) - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue - end - cf = LLVM.called_operand(c) - if !isa(cf, LLVM.Function) - continue - end - if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" - continue - end - if operands(c)[1] != fn - continue - end - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - inactive, - ) - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - no_escaping_alloc, - ) - end - end - end - end - - for fname in nofreefns - if haskey(funcs, fname) - for fn in funcs[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("nofree", 0)) - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue - end - cf = LLVM.called_operand(c) - if !isa(cf, LLVM.Function) - continue - end - if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" - continue - end - if operands(c)[1] != fn - continue - end - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - LLVM.EnumAttribute("nofree", 0), - ) - end - end - end - end - - for fname in activefns - if haskey(funcs, fname) - for fn in funcs[fname] - push!(function_attributes(fn), active) - end - end - end - - for fname in - ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") - if haskey(funcs, fname) - for fn in funcs[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) - end - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) - end - end - end - 
for fname in ("julia.typeof",) - if haskey(funcs, fname) - for fn in funcs[fname] - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) - push!(parameter_attributes(fn, 1), LLVM.EnumAttribute("nocapture")) - end - end - end - - for fname in - ("jl_excstack_state", "ijl_excstack_state", "ijl_field_index", "jl_field_index") - if haskey(funcs, fname) - for fn in funcs[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) - push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_Ref << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) - end - end - end - end - - for fname in ("jl_types_equal", "ijl_types_equal") - if haskey(funcs, fname) - for fn in funcs[fname] - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) - end - end - end - - for fname in ( - "UnsafeBufferPointer", - ) - if haskey(funcs, fname) - for fn in funcs[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_math", "__dynamic_cast")) - end - end - end - end - - for fname in ( - "jl_f_getfield", - "ijl_f_getfield", - "jl_get_nth_field_checked", - "ijl_get_nth_field_checked", - "jl_f__svec_ref", - "ijl_f__svec_ref", - "UnsafeBufferPointer" - ) - if haskey(funcs, fname) - for fn in funcs[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - else - push!(function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) - ) - end - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue - end - cf = LLVM.called_operand(c) - if 
!isa(cf, LLVM.Function) - continue - end - if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" - continue - end - if operands(c)[1] != fn - continue - end - attr = if LLVM.version().major <= 15 - LLVM.EnumAttribute("readonly") - else - EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) - end - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - attr, - ) - end - end - end - end - - for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states") - if haskey(funcs, fname) - for fn in funcs[fname] - # TODO per discussion w keno perhaps this should change to readonly / inaccessiblememonly - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) - end - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) - end - end - end - - for fname in ("julia.gc_loaded",) - if haskey(funcs, fname) - for fn in funcs[fname] - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) - push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) - end - end - end - - for fname in ( - "julia.get_pgcstack", - "julia.ptls_states", - "jl_get_ptls_states", - "julia.safepoint", - "ijl_throw", - "julia.pointer_from_objref", - "ijl_array_grow_end", - "jl_array_grow_end", - "ijl_array_del_end", - "jl_array_del_end", - "ijl_array_grow_beg", - "jl_array_grow_beg", - "ijl_array_del_beg", - "jl_array_del_beg", - "ijl_array_grow_at", - "jl_array_grow_at", - "ijl_array_del_at", - "jl_array_del_at", - "ijl_pop_handler", - "jl_pop_handler", - "ijl_pop_handler_noexcept", - "jl_pop_handler_noexcept", - "ijl_push_handler", - "jl_push_handler", - "ijl_module_name", - 
"jl_module_name", - "ijl_restore_excstack", - "jl_restore_excstack", - "julia.except_enter", - "ijl_get_nth_field_checked", - "jl_get_nth_field_checked", - "jl_egal__unboxed", - "ijl_reshape_array", - "jl_reshape_array", - "ijl_eqtable_get", - "jl_eqtable_get", - "jl_gc_run_pending_finalizers", - "ijl_try_substrtod", - "jl_try_substrtod", - ) - if haskey(funcs, fname) - for fn in funcs[fname] - push!(function_attributes(fn), no_escaping_alloc) - end - end - end - - + ) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo - for fname in ("julia.pointer_from_objref",) - if haskey(funcs, fname) - for fn in funcs[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) - end - end - end + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + new_ci.edges = Core.MethodInstance[mi] + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. 
- for fname in ( - "julia.gc_alloc_obj", - "jl_gc_alloc_typed", - "ijl_gc_alloc_typed", - ) - if haskey(funcs, fname) - for fn in funcs[fname] - LLVM.API.LLVMRemoveEnumAttributeAtIndex( - fn, - reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), - kind(EnumAttribute("allockind", AllocFnKind(AFKE_Alloc).data)), - ) - push!(function_attributes(fn), no_escaping_alloc) - push!(function_attributes(fn), LLVM.EnumAttribute("allockind", (AllocFnKind(AFKE_Alloc) | AllocFnKind(AFKE_Uninitialized)).data)) - end - end - end - - for fname in ( - "julia.gc_alloc_obj", - "jl_gc_alloc_typed", - "ijl_gc_alloc_typed", - "jl_box_float32", - "jl_box_float64", - "jl_box_int32", - "jl_box_int64", - "ijl_box_float32", - "ijl_box_float64", - "ijl_box_int32", - "ijl_box_int64", - "jl_alloc_genericmemory", - "ijl_alloc_genericmemory", - "jl_alloc_array_1d", - "jl_alloc_array_2d", - "jl_alloc_array_3d", - "ijl_alloc_array_1d", - "ijl_alloc_array_2d", - "ijl_alloc_array_3d", - "jl_array_copy", - "ijl_array_copy", - "jl_genericmemory_copy_slice", - "ijl_genericmemory_copy_slice", - "jl_alloc_genericmemory", - "ijl_alloc_genericmemory", - "jl_idtable_rehash", - "ijl_idtable_rehash", - "jl_f_tuple", - "ijl_f_tuple", - "jl_new_structv", - "ijl_new_structv", - "ijl_new_array", - "jl_new_array", - ) - if haskey(funcs, fname) - for fn in funcs[fname] - push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) - push!(return_attributes(fn), LLVM.EnumAttribute("nonnull", 0)) - push!(function_attributes(fn), no_escaping_alloc) - push!(function_attributes(fn), LLVM.EnumAttribute("mustprogress")) - push!(function_attributes(fn), LLVM.EnumAttribute("willreturn")) - push!(function_attributes(fn), LLVM.EnumAttribute("nounwind")) - push!(function_attributes(fn), LLVM.EnumAttribute("nofree")) - accattr = if LLVM.version().major <= 15 - LLVM.EnumAttribute("inaccessiblememonly") - else - if fname in ( - "jl_genericmemory_copy_slice", - "ijl_genericmemory_copy_slice",) - 
EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) - else - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ) - end - end - if !( - fname in ( - "jl_array_copy", - "ijl_array_copy", - "jl_idtable_rehash", - "ijl_idtable_rehash", - ) - ) - push!(function_attributes(fn), accattr) - end - for u in LLVM.uses(fn) - c = LLVM.user(u) - if !isa(c, LLVM.CallInst) - continue - end - cf = LLVM.called_operand(c) - if cf == fn - LLVM.API.LLVMAddCallSiteAttribute( - c, - LLVM.API.LLVMAttributeReturnIndex, - LLVM.EnumAttribute("noalias", 0), - ) - if !( - fname in ( - "jl_array_copy", - "ijl_array_copy", - "jl_idtable_rehash", - "ijl_idtable_rehash", - ) - ) - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - accattr, - ) - end - end - if !isa(cf, LLVM.Function) - continue - end - if !(cf == fn || - ((LLVM.name(cf) == "julia.call" || LLVM.name(cf) != "julia.call2") && operands(c)[1] == fn)) - continue - end - LLVM.API.LLVMAddCallSiteAttribute( - c, - LLVM.API.LLVMAttributeReturnIndex, - LLVM.EnumAttribute("noalias", 0), - ) - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - no_escaping_alloc, - ) - if !( - fname in ( - "jl_array_copy", - "ijl_array_copy", - "jl_idtable_rehash", - "ijl_idtable_rehash", - ) - ) - LLVM.API.LLVMAddCallSiteAttribute( - c, - reinterpret( - LLVM.API.LLVMAttributeIndex, - LLVM.API.LLVMAttributeFunctionIndex, - ), - accattr, - ) - end - end - end - end - end + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :mode, :ft, :tt] + new_ci.slotflags = UInt8[0x00 for i = 1:4] - for fname in 
("llvm.julia.gc_preserve_begin", "llvm.julia.gc_preserve_end") - if haskey(funcs, fname) - for fn in funcs[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_ModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) - end - end - end + # return the codegen world age + res = primal_return_type_world(mode, world, mi) + push!(new_ci.code, Core.Compiler.ReturnNode(res)) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below end + new_ci.ssavaluetypes += 1 - # Key of jl_eqtable_get/put is inactive, definitionally - for fname in ("jl_eqtable_get", "ijl_eqtable_get") - if haskey(funcs, fname) - for fn in funcs[fname] - push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_Ref << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) - end - end - end - end - - for fname in ("jl_reshape_array", "ijl_reshape_array") - if haskey(funcs, fname) - for fn in funcs[fname] - push!(parameter_attributes(fn, 3), LLVM.EnumAttribute("readonly")) - push!(parameter_attributes(fn, 3), LLVM.EnumAttribute("nocapture")) - end - end - end - - # Key of jl_eqtable_get/put is inactive, definitionally - for fname in ("jl_eqtable_put", "ijl_eqtable_put") - if haskey(funcs, fname) - for fn in funcs[fname] - 
push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) - push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) - push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) - push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_ModRef << getLocationPos(ArgMem)) | - (MRI_NoModRef << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) - end - end - end - end + # NOTE: we keep the first entry of the original linetable, and use it for location info + # on the call to check_cache. we can't not have a codeloc (using 0 causes + # corruption of the back trace), and reusing the target function's info + # has as advantage that we see the name of the kernel in the backtraces. - for fname in ("jl_in_threaded_region_", "jl_in_threaded_region") - if haskey(funcs, fname) - for fn in funcs[fname] - if LLVM.version().major <= 15 - push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) - else - push!( - function_attributes(fn), - EnumAttribute( - "memory", - MemoryEffect( - (MRI_NoModRef << getLocationPos(ArgMem)) | - (MRI_Ref << getLocationPos(InaccessibleMem)) | - (MRI_NoModRef << getLocationPos(Other)), - ).data, - ), - ) - end - end - end - end + return new_ci +end + +@eval Base.@assume_effects :removable :foldable :nothrow @inline function primal_return_type(mode::Mode, ft::Type, tt::Type) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, primal_return_type_generator)) end +## +# Enzyme compiler step +## + function enzyme_extract_world(fn::LLVM.Function)::UInt for fattr in collect(function_attributes(fn)) if isa(fattr, LLVM.StringAttribute) @@ -7494,185 +5524,7 @@ end LLVM.run!(pm, 
mod) end if parent_job !== nothing - if parent_job.config.target isa GPUCompiler.PTXCompilerTarget - arg1 = ( - "sin", - "cos", - "tan", - "log2", - "exp", - "exp2", - "exp10", - "cosh", - "sinh", - "tanh", - "atan", - "asin", - "acos", - "log", - "log10", - "log1p", - "acosh", - "asinh", - "atanh", - "expm1", - "cbrt", - "rcbrt", - "j0", - "j1", - "y0", - "y1", - "erf", - "erfinv", - "erfc", - "erfcx", - "erfcinv", - "remquo", - "tgamma", - "round", - "fdim", - "logb", - "isinf", - "sqrt", - "fabs", - "atan2", - ) - # isinf, finite "modf", "fmod", "remainder", - # "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", - # "hypot", "rhypot", - # "yn", "jn", "norm3d", "ilogb", powi - # "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", - # arg1 = ("atan2", "fmax", "pow") - for n in arg1, - (T, pf, lpf) in - ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) - - fname = "__nv_" * n * pf - if !haskey(functions(mod), fname) - FT = LLVM.FunctionType(T, [T], vararg = false) - wrapper_f = LLVM.Function(mod, fname, FT) - llname = "llvm." * n * "." 
* lpf - push!( - function_attributes(wrapper_f), - StringAttribute("implements", llname), - ) - push!( - function_attributes(wrapper_f), - StringAttribute("implements2", n * pf) - ) - end - end - end - if parent_job.config.target isa GPUCompiler.GCNCompilerTarget - arg1 = ( - "acos", - "acosh", - "asin", - "asinh", - "atan2", - "atan", - "atanh", - "cbrt", - "ceil", - "copysign", - "cos", - "native_cos", - "cosh", - "cospi", - "i0", - "i1", - "erfc", - "erfcinv", - "erfcx", - "erf", - "erfinv", - "exp10", - "native_exp10", - "exp2", - "exp", - "native_exp", - "expm1", - "fabs", - "fdim", - "floor", - "fma", - "fmax", - "fmin", - "fmod", - "frexp", - "hypot", - "ilogb", - "isfinite", - "isinf", - "isnan", - "j0", - "j1", - "ldexp", - "lgamma", - "log10", - "native_log10", - "log1p", - "log2", - "log2", - "logb", - "log", - "native_log", - "modf", - "nearbyint", - "nextafter", - "len3", - "len4", - "ncdf", - "ncdfinv", - "pow", - "pown", - "rcbrt", - "remainder", - "remquo", - "rhypot", - "rint", - "rlen3", - "rlen4", - "round", - "rsqrt", - "scalb", - "scalbn", - "signbit", - "sincos", - "sincospi", - "sin", - "native_sin", - "sinh", - "sinpi", - "sqrt", - "native_sqrt", - "tan", - "tanh", - "tgamma", - "trunc", - "y0", - "y1", - ) - for n in arg1, - (T, pf, lpf) in - ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) - - fname = "__ocml_" * n * "_" * lpf - if !haskey(functions(mod), fname) - FT = LLVM.FunctionType(T, [T], vararg = false) - wrapper_f = LLVM.Function(mod, fname, FT) - llname = "llvm." * n * "." 
* lpf - push!( - function_attributes(wrapper_f), - StringAttribute("implements", llname), - ) - push!( - function_attributes(wrapper_f), - StringAttribute("implements2", n * pf) - ) - end - end - end + mark_gpu_intrinsics!(parent_job.config.target, mod) end for (name, fnty) in fnsToInject for (T, JT, pf) in @@ -7888,92 +5740,7 @@ end end end -# Recursively return x + f(y), where y is active, otherwise x - -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T,F,F2} - if forcelhs(T) - return x - end - splatnew(T, ntuple(Val(fieldcount(T))) do i - Base.@_inline_meta - prev = getfield(x, i) - next = getfield(y, i) - recursive_add(prev, next, f, forcelhs) - end) -end - -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:AbstractFloat,F,F2} - if forcelhs(T) - return x - end - return x + f(y) -end - -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:Complex,F,F2} - if forcelhs(T) - return x - end - return x + f(y) -end - -@inline mutable_register(::Type{T}) where {T<:Integer} = true -@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{T}) where {T<:Tuple} = false -@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false -@inline mutable_register(::Type{Core.Box}) = true -@inline mutable_register(::Type{T}) where {T<:Array} = true -@inline mutable_register(::Type{T}) where {T} = ismutabletype(T) - -# Recursively In-place accumulate(aka +=). E.g. 
generalization of x .+= f(y) -@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F} - if !mutable_register(T) - for I in eachindex(x) - prev = x[I] - @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) - end - end -end - - -# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F} - recursive_accumulate(x.contents, y.contents, seen, f) -end - -@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F} - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - - for i = 1:nf - if isdefined(x, i) - xi = getfield(x, i) - ST = Core.Typeof(xi) - if !mutable_register(ST) - @assert ismutable(x) - yi = getfield(y, i) - nexti = recursive_add(xi, yi, f, mutable_register) - setfield!(x, i, nexti) - end - end - end -end +include("typeutils/recursive_add.jl") @inline function default_adjoint(T) if T == Union{} @@ -9064,126 +6831,4 @@ end include("compiler/reflection.jl") -@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:Array} - ir = JuliaContext() do ctx - Base.@_inline_meta - - target = Compiler.DefaultCompilerTarget() - params = Compiler.PrimalCompilerParams(API.DEM_ForwardMode) - mi = my_methodinstance(fn, Tuple{T, Int}) - job = CompilerJob(mi, CompilerConfig(target, params; kernel = false)) - mod, meta = GPUCompiler.codegen( - :llvm, - job; - optimize = false, - cleanup = false, - validate = false, - ) - copysetfn = meta.entry - blk = first(blocks(copysetfn)) - iter = LLVM.API.LLVMGetFirstInstruction(blk) - while iter != C_NULL - inst = LLVM.Instruction(iter) - iter = LLVM.API.LLVMGetNextInstruction(iter) - if isa(inst, LLVM.FenceInst) - eraseInst(blk, inst) - end - if isa(inst, LLVM.CallInst) - fn = LLVM.called_operand(inst) - if isa(fn, LLVM.Function) - if LLVM.name(fn) == "julia.safepoint" - eraseInst(blk, 
inst) - end - end - end - end - hasNoRet = any( - map( - k -> kind(k) == kind(EnumAttribute("noreturn")), - collect(function_attributes(copysetfn)), - ), - ) - @assert !hasNoRet - if !hasNoRet - push!(function_attributes(copysetfn), EnumAttribute("alwaysinline", 0)) - end - ity = convert(LLVMType, Int) - jlvaluet = convert(LLVMType, T; allow_boxed=true) - - FT = LLVM.FunctionType(jlvaluet, LLVMType[jlvaluet, ity, ity]) - llvm_f = LLVM.Function(mod, "f", FT) - push!(function_attributes(llvm_f), EnumAttribute("alwaysinline", 0)) - - # Check if Julia version has https://github.com/JuliaLang/julia/pull/46914 - # and also https://github.com/JuliaLang/julia/pull/47076 - # and also https://github.com/JuliaLang/julia/pull/48620 - needs_dynamic_size_workaround = !(VERSION >= v"1.10.5") - - builder = LLVM.IRBuilder() - entry = BasicBlock(llvm_f, "entry") - position!(builder, entry) - inp, lstart, len = collect(LLVM.Value, parameters(llvm_f)) - - boxed_count = if sizeof(Int) == sizeof(Int64) - emit_box_int64!(builder, len) - else - emit_box_int32!(builder, len) - end - - tag = emit_apply_type!(builder, NTuple, LLVM.Value[boxed_count, unsafe_to_llvm(builder, T)]) - - fullsize = nuwmul!(builder, len, LLVM.ConstantInt(sizeof(Int))) - obj = emit_allocobj!(builder, tag, fullsize, needs_dynamic_size_workaround) - - T_int8 = LLVM.Int8Type() - LLVM.memset!(builder, obj, LLVM.ConstantInt(T_int8, 0), fullsize, 0) - - alloc = pointercast!(builder, obj, LLVM.PointerType(jlvaluet, Tracked)) - alloc = pointercast!(builder, alloc, LLVM.PointerType(jlvaluet, 11)) - - loop = BasicBlock(llvm_f, "loop") - exit = BasicBlock(llvm_f, "exit") - - br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, LLVM.ConstantInt(0), len), exit, loop) - - position!(builder, loop) - idx = phi!(builder, ity) - - push!(LLVM.incoming(idx), (LLVM.ConstantInt(0), entry)) - inc = add!(builder, idx, LLVM.ConstantInt(1)) - push!(LLVM.incoming(idx), (inc, loop)) - rval = add!(builder, inc, lstart) - res = call!(builder, 
LLVM.function_type(copysetfn), copysetfn, [inp, rval]) - if !hasNoRet - gidx = gep!(builder, jlvaluet, alloc, [idx]) - store!(builder, res, gidx) - emit_writebarrier!(builder, get_julia_inner_types(builder, obj, res)) - end - - br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, inc, len), exit, loop) - - - T_int32 = LLVM.Int32Type() - - reinsert_gcmarker!(llvm_f) - - position!(builder, exit) - ret!(builder, obj) - - string(mod) - end - return quote - Base.@_inline_meta - Base.llvmcall( - ($ir, "f"), - Tuple{Vararg{T}}, - Tuple{T, Int, Int}, - x, - startv, - lengthv - ) - end -end - - end diff --git a/src/errors.jl b/src/errors.jl new file mode 100644 index 0000000000..c6dd78b781 --- /dev/null +++ b/src/errors.jl @@ -0,0 +1,640 @@ +abstract type CompilationException <: Base.Exception end +struct NoDerivativeException <: CompilationException + msg::String + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} +end + +function Base.showerror(io::IO, ece::NoDerivativeException) + print(io, "Enzyme compilation failed.\n") + if ece.ir !== nothing + print(io, "Current scope: \n") + print(io, ece.ir) + end + print(io, '\n', ece.msg, '\n') + if ece.bt !== nothing + Base.show_backtrace(io, ece.bt) + println(io) + end +end + +struct IllegalTypeAnalysisException <: CompilationException + msg::String + sval::String + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} +end + +function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) + print(io, "Enzyme compilation failed due to illegal type analysis.\n") + if ece.ir !== nothing + print(io, "Current scope: \n") + print(io, ece.ir) + end + print(io, "\n Type analysis state: \n") + write(io, ece.sval) + print(io, '\n', ece.msg, '\n') + if ece.bt !== nothing + print(io, "\nCaused by:") + Base.show_backtrace(io, ece.bt) + println(io) + end +end + +struct IllegalFirstPointerException <: CompilationException + msg::String + ir::Union{Nothing,String} + 
bt::Union{Nothing,Vector{StackTraces.StackFrame}} +end + +function Base.showerror(io::IO, ece::IllegalFirstPointerException) + print(io, "Enzyme compilation failed.\n") + if ece.ir !== nothing + print(io, "Current scope: \n") + print(io, ece.ir) + end + print(io, '\n', ece.msg, '\n') + if ece.bt !== nothing + Base.show_backtrace(io, ece.bt) + println(io) + end +end + +struct EnzymeInternalError <: CompilationException + msg::String + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} +end + +function Base.showerror(io::IO, ece::EnzymeInternalError) + print(io, "Enzyme compilation failed.\n") + if ece.ir !== nothing + print(io, "Current scope: \n") + print(io, ece.ir) + end + print(io, '\n', ece.msg, '\n') + if ece.bt !== nothing + Base.show_backtrace(io, ece.bt) + println(io) + end +end + +parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? LLVM.parent(val) : val +parent_scope(val::LLVM.Module, depth = 0) = val +parent_scope(@nospecialize(val::LLVM.Value), depth = 0) = parent_scope(LLVM.parent(val), depth + 1) +parent_scope(val::LLVM.Argument, depth = 0) = + parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth + 1) + +function julia_error( + cstr::Cstring, + val::LLVM.API.LLVMValueRef, + errtype::API.ErrorType, + data::Ptr{Cvoid}, + data2::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, +)::LLVM.API.LLVMValueRef + msg = Base.unsafe_string(cstr) + bt = nothing + ir = nothing + if val != C_NULL + val = LLVM.Value(val) + if isa(val, LLVM.Instruction) + dbgval = val + while !haskey(metadata(dbgval), LLVM.MD_dbg) + dbgval = LLVM.API.LLVMGetNextInstruction(dbgval) + if dbgval == C_NULL + dbgval = nothing + break + else + dbgval = LLVM.Instruction(dbgval) + end + end + if dbgval !== nothing + bt = GPUCompiler.backtrace(dbgval) + end + end + if isa(val, LLVM.ConstantExpr) + for u in LLVM.uses(val) + u = LLVM.user(u) + if isa(u, LLVM.Instruction) + bt = GPUCompiler.backtrace(val) + end + end + else + # Need to convert 
function to string, since when the error is going to be printed + # the module might have been destroyed + ir = string(parent_scope(val)) + end + end + + if errtype == API.ET_NoDerivative + if occursin("No create nofree of empty function", msg) || + occursin("No forward mode derivative found for", msg) || + occursin("No augmented forward pass", msg) || + occursin("No reverse pass found", msg) + ir = nothing + end + if B != C_NULL + B = IRBuilder(B) + msg2 = sprint() do io + if ir !== nothing + print(io, "Current scope: \n") + print(io, ir) + end + print(io, '\n', msg, '\n') + if bt !== nothing + Base.show_backtrace(io, bt) + println(io) + end + end + emit_error(B, nothing, msg2, EnzymeNoDerivativeError) + return C_NULL + end + throw(NoDerivativeException(msg, ir, bt)) + elseif errtype == API.ET_NoShadow + gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) + + msgN = sprint() do io::IO + if isa(val, LLVM.Argument) + fn = parent_scope(val) + ir = string(LLVM.name(fn)) * string(function_type(fn)) + print(io, "Current scope: \n") + print(io, ir) + end + if !isa(val, LLVM.Argument) + print(io, "\n Inverted pointers: \n") + ip = API.EnzymeGradientUtilsInvertedPointersToString(gutils) + sval = Base.unsafe_string(ip) + write(io, sval) + API.EnzymeStringFree(ip) + end + print(io, '\n', msg, '\n') + if bt !== nothing + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end + end + emit_error(IRBuilder(B), nothing, msgN, EnzymeNoShadowError) + return LLVM.null(get_shadow_type(gutils, value_type(val))).ref + elseif errtype == API.ET_IllegalTypeAnalysis + data = API.EnzymeTypeAnalyzerRef(data) + ip = API.EnzymeTypeAnalyzerToString(data) + sval = Base.unsafe_string(ip) + API.EnzymeStringFree(ip) + + if isa(val, LLVM.Instruction) + mi, rt = enzyme_custom_extract_mi( + LLVM.parent(LLVM.parent(val))::LLVM.Function, + false, + ) #=error=# + if mi !== nothing + msg *= "\n" * string(mi) * "\n" + end + end + throw(IllegalTypeAnalysisException(msg, sval, ir, 
bt)) + elseif errtype == API.ET_NoType + @assert B != C_NULL + B = IRBuilder(B) + + data = API.EnzymeTypeAnalyzerRef(data) + ip = API.EnzymeTypeAnalyzerToString(data) + sval = Base.unsafe_string(ip) + API.EnzymeStringFree(ip) + + msg2 = sprint() do io::IO + if !occursin("Cannot deduce single type of store", msg) + if ir !== nothing + print(io, "Current scope: \n") + print(io, ir) + end + print(io, "\n Type analysis state: \n") + write(io, sval) + end + print(io, '\n', msg, '\n') + if bt !== nothing + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end + pscope = parent_scope(val) + mi, rt = enzyme_custom_extract_mi(pscope, false) #=error=# + if mi !== nothing + println(io, "within ", mi) + end + end + emit_error(B, nothing, msg2, EnzymeNoTypeError) + return C_NULL + elseif errtype == API.ET_IllegalFirstPointer + throw(IllegalFirstPointerException(msg, ir, bt)) + elseif errtype == API.ET_InternalError + throw(EnzymeInternalError(msg, ir, bt)) + elseif errtype == API.ET_TypeDepthExceeded + msg2 = sprint() do io + print(io, msg) + println(io) + + if val != C_NULL + println(io, val) + end + + st = API.EnzymeTypeTreeToString(data) + println(io, Base.unsafe_string(st)) + API.EnzymeStringFree(st) + + if bt !== nothing + Base.show_backtrace(io, bt) + end + end + GPUCompiler.@safe_warn msg2 + return C_NULL + elseif errtype == API.ET_IllegalReplaceFicticiousPHIs + data2 = LLVM.Value(data2) + msg2 = sprint() do io + print(io, msg) + println(io) + println(io, string(LLVM.parent(LLVM.parent(data2)))) + println(io, val) + println(io, data2) + end + throw(EnzymeInternalError(msg2, ir, bt)) + elseif errtype == API.ET_MixedActivityError + data2 = LLVM.Value(data2) + badval = nothing + gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) + # Ignore mismatched activity if phi/store of ghost + seen = Dict{LLVM.Value,LLVM.Value}() + illegal = false + created = LLVM.Instruction[] + world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) + width = 
get_width(gutils) + function make_batched(@nospecialize(cur::LLVM.Value), B::LLVM.IRBuilder)::LLVM.Value + if width == 1 + return cur + else + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, cur, idx - 1) + if isa(shadowres, LLVM.Instruction) + push!(created, shadowres) + end + end + return shadowres + end + end + + illegalVal = nothing + + function make_replacement(@nospecialize(cur::LLVM.Value), prevbb::LLVM.IRBuilder)::LLVM.Value + ncur = new_from_original(gutils, cur) + if cur in keys(seen) + return seen[cur] + end + + legal, TT, byref = abs_typeof(cur, true) + + if legal + if guaranteed_const_nongen(TT, world) + return make_batched(ncur, prevbb) + end + + legal2, obj = absint(cur) + + # Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple + if legal2 + if active_reg_inner(TT, (), world) == ActiveState && + isa(cur, LLVM.ConstantExpr) && + cur == data2 + if width == 1 + res = emit_allocobj!(prevbb, Base.RefValue{TT}) + push!(created, res) + return res + else + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), + ) + for idx = 1:width + res = emit_allocobj!(prevbb, Base.RefValue{TT}) + shadowres = insert_value!(prevbb, shadowres, res, idx - 1) + push!(created, shadowres) + end + return shadowres + end + end + +@static if VERSION < v"1.11-" +else + if obj isa Memory && obj == typeof(obj).instance + return make_batched(ncur, prevbb) + end +end + end + +@static if VERSION < v"1.11-" +else + if isa(cur, LLVM.LoadInst) + larg, off = get_base_and_offset(operands(cur)[1]) + if isa(larg, LLVM.LoadInst) + legal2, obj = absint(larg) + if legal2 && obj isa Memory && obj == typeof(obj).instance + return make_batched(ncur, prevbb) + end + end + end +end + + badval = if legal2 + string(obj) * " of type" * " " * string(TT) + else + "Unknown object of type" * " " * string(TT) + end + @assert 
!illegal + illegalVal = cur + illegal = true + return make_batched(ncur, prevbb) + end + + if isa(cur, LLVM.PointerNull) + return make_batched(ncur, prevbb) + end + if isa(cur, LLVM.UndefValue) + return make_batched(ncur, prevbb) + end + if isa(cur, LLVM.PoisonValue) + return make_batched(ncur, prevbb) + end + if isa(cur, LLVM.ConstantAggregateZero) + return make_batched(ncur, prevbb) + end + if isa(cur, LLVM.ConstantAggregate) + return make_batched(ncur, prevbb) + end + if isa(cur, LLVM.ConstantInt) + if convert(UInt64, cur) == 0 + return make_batched(ncur, prevbb) + end + end + if isa(cur, LLVM.ConstantFP) + return make_batched(ConstantFP(value_type(cur), 0), prevbb) + end + if isa(cur, LLVM.ConstantDataSequential) + cvals = LLVM.Value[] + changed = false + for v in collect(cur) + tmp = make_replacement(v, prevbb) + if illegal + return ncur + end + if v != tmp + changed = true + end + push!(cvals, tmp) + end + + cur2 = if changed + @assert !illegal + illegalVal = cur + illegal = true + # TODO replace with correct insertions/splats + ncur + else + make_batched(ncur, prevbb) + end + return cur2 + end + if isa(cur, LLVM.ConstantInt) + if LLVM.width(value_type(cur)) <= sizeof(Int) * 8 + return make_batched(ncur, prevbb) + end + if LLVM.width(value_type(cur)) == sizeof(Int) * 8 && + abs(convert(Int, cur)) < 10000 + return make_batched(ncur, prevbb) + end + # if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe + # for activity state to mix + if isa(val, LLVM.StoreInst) + operands(val)[1] == cur && + !isa(value_type(operands(val)[1]), LLVM.PointerType) + return make_batched(ncur, prevbb) + end + end + + if isa(cur, LLVM.SelectInst) + lhs = make_replacement(operands(cur)[2], prevbb) + if illegal + return ncur + end + rhs = make_replacement(operands(cur)[3], prevbb) + if illegal + return ncur + end + if lhs == operands(cur)[2] && rhs == operands(cur)[3] + return make_batched(ncur, prevbb) + end + if width == 1 + nv = select!( + prevbb, + 
new_from_original(gutils, operands(cur)[1]), + lhs, + rhs, + ) + push!(created, nv) + seen[cur] = nv + return nv + else + shadowres = LLVM.UndefValue(value_type(lhs)) + for idx = 1:width + shadowres = insert_value!( + prevbb, + shadowres, + select!( + prevbb, + new_from_original(gutils, operands(cur)[1]), + extract_value!(prevbb, lhs, idx - 1), + extract_value!(prevbb, rhs, idx - 1), + ), + idx - 1, + ) + if isa(shadowres, LLVM.Instruction) + push!(created, shadowres) + end + end + return shadowres + end + end + + if isa(cur, LLVM.InsertValueInst) + lhs = make_replacement(operands(cur)[1], prevbb) + if illegal + return ncur + end + rhs = make_replacement(operands(cur)[2], prevbb) + if illegal + return ncur + end + if lhs == operands(cur)[1] && rhs == operands(cur)[2] + return make_batched(ncur, prevbb) + end + inds = LLVM.API.LLVMGetIndices(cur.ref) + ninds = LLVM.API.LLVMGetNumIndices(cur.ref) + jinds = Cuint[unsafe_load(inds, i) for i = 1:ninds] + if width == 1 + nv = API.EnzymeInsertValue(prevbb, lhs, rhs, jinds) + push!(created, nv) + seen[cur] = nv + return nv + else + shadowres = lhs + for idx = 1:width + jindsv = copy(jinds) + pushfirst!(jindsv, idx - 1) + shadowres = API.EnzymeInsertValue( + prevbb, + shadowres, + extract_value!(prevbb, rhs, idx - 1), + jindsv, + ) + if isa(shadowres, LLVM.Instruction) + push!(created, shadowres) + end + end + return shadowres + end + end + + if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(Base.Fix2(isa, LLVM.ConstantInt), operands(cur)[2:end])) + lhs = make_replacement(operands(cur)[1], prevbb) + if illegal + return ncur + end + if lhs == operands(ncur)[1] + return make_batched(ncur, prevbb) + elseif width != 1 && isa(lhs, LLVM.InsertValueInst) && operands(lhs)[2] == operands(ncur)[1] + return make_batched(ncur, prevbb) + end + end + + if isa(cur, LLVM.PHIInst) + Bphi = IRBuilder() + position!(Bphi, ncur) + shadowty = 
LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))) + phi2 = phi!(Bphi, shadowty, "tempphi" * LLVM.name(cur)) + seen[cur] = phi2 + changed = false + recsize = length(created) + 1 + for (v, bb) in LLVM.incoming(cur) + B2 = IRBuilder() + position!(B2, new_from_original(gutils, last(instructions(bb)))) + tmp = make_replacement(v, B2) + if illegal + changed = true + break + end + @assert value_type(tmp) == shadowty + if tmp != new_from_original(gutils, v) && v != cur + changed = true + end + push!(LLVM.incoming(phi2), (tmp, new_from_original(gutils, bb))) + end + if !changed || illegal + LLVM.API.LLVMInstructionEraseFromParent(phi2) + seen[cur] = ncur + plen = length(created) + for i = recsize:plen + u = created[i] + replace_uses!(u, LLVM.UndefValue(value_type(u))) + end + for i = recsize:plen + u = created[i] + LLVM.API.LLVMInstructionEraseFromParent(u) + end + for i = recsize:plen + pop!(created) + end + return illegal ? ncur : make_batched(ncur, prevbb) + end + push!(created, phi2) + return phi2 + end + + tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, cur)) + st = API.EnzymeTypeTreeToString(tt) + st2 = Base.unsafe_string(st) + API.EnzymeStringFree(st) + if st2 == "{[-1]:Integer}" + return make_batched(ncur, prevbb) + end + + if !illegal + illegal = true + illegalVal = cur + end + return ncur + end + + b = IRBuilder(B) + replacement = make_replacement(data2, b) + + if !illegal + return replacement.ref + end + for u in created + replace_uses!(u, LLVM.UndefValue(value_type(u))) + end + for u in created + LLVM.API.LLVMInstructionEraseFromParent(u) + end + if LLVM.API.LLVMIsAReturnInst(val) != C_NULL + mi, rt = enzyme_custom_extract_mi( + LLVM.parent(LLVM.parent(val))::LLVM.Function, + false, + ) #=error=# + if mi !== nothing && isghostty(rt) + return C_NULL + end + end + msg2 = sprint() do io + print(io, msg) + println(io) + if badval !== nothing + println(io, " value=" * badval) + else + ttval = val + if isa(ttval, LLVM.StoreInst) + ttval = 
operands(ttval)[1] + end + tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, ttval)) + st = API.EnzymeTypeTreeToString(tt) + print(io, "Type tree: ") + println(io, Base.unsafe_string(st)) + API.EnzymeStringFree(st) + end + if illegalVal !== nothing + println(io, " llvalue=" * string(illegalVal)) + end + if bt !== nothing + Base.show_backtrace(io, bt) + end + end + emit_error(b, nothing, msg2, EnzymeRuntimeActivityError) + return C_NULL + elseif errtype == API.ET_GetIndexError + @assert B != C_NULL + B = IRBuilder(B) + msg5 = sprint() do io::IO + print(io, "Enzyme internal error\n") + print(io, msg, '\n') + if bt !== nothing + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end + end + emit_error(B, nothing, msg5) + return C_NULL + end + throw(AssertionError("Unknown errtype")) +end + diff --git a/src/llvm/attributes.jl b/src/llvm/attributes.jl new file mode 100644 index 0000000000..3dd5973421 --- /dev/null +++ b/src/llvm/attributes.jl @@ -0,0 +1,1060 @@ +const nofreefns = Set{String}(( + "ClientGetDevice", + "BufferOnCPU", + "pcre2_match_8", + "julia.gcroot_flush", + "pcre2_jit_stack_assign_8", + "pcre2_match_context_create_8", + "pcre2_jit_stack_create_8", + "ijl_gc_enable_finalizers_internal", + "jl_gc_enable_finalizers_internal", + "pcre2_match_data_create_from_pattern_8", + "ijl_gc_run_pending_finalizers", + "jl_gc_run_pending_finalizers", + "ijl_typeassert", + "jl_typeassert", + "ijl_f_isdefined", + "jl_f_isdefined", + "ijl_field_index", + "jl_field_index", + "ijl_specializations_get_linfo", + "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", + "jl_gf_invoke_lookup_worlds", + "ijl_gc_get_total_bytes", + "jl_gc_get_total_bytes", + "ijl_array_grow_at", + "jl_array_grow_at", + "ijl_try_substrtod", + "jl_try_substrtod", + "jl_f__apply_iterate", + "ijl_field_index", + "jl_field_index", + "julia.call", + "julia.call2", + "ijl_tagged_gensym", + "jl_tagged_gensym", + "ijl_array_ptr_copy", + "jl_array_ptr_copy", + 
"ijl_array_copy", + "jl_array_copy", + "ijl_genericmemory_copy_slice", + "jl_genericmemory_copy_slice", + "ijl_get_nth_field_checked", + "jl_get_nth_field_checked", + "jl_array_del_end", + "ijl_array_del_end", + "jl_get_world_counter", + "ijl_get_world_counter", + "memhash32_seed", + "memhash_seed", + "ijl_module_parent", + "jl_module_parent", + "julia.safepoint", + "ijl_set_task_tid", + "jl_set_task_tid", + "ijl_get_task_tid", + "jl_get_task_tid", + "julia.get_pgcstack_or_new", + "ijl_global_event_loop", + "jl_global_event_loop", + "ijl_gf_invoke_lookup", + "jl_gf_invoke_lookup", + "ijl_f_typeassert", + "jl_f_typeassert", + "ijl_type_unionall", + "jl_type_unionall", + "jl_gc_queue_root", + "gpu_report_exception", + "gpu_signal_exception", + "julia.ptls_states", + "julia.write_barrier", + "julia.typeof", + "jl_backtrace_from_here", + "ijl_backtrace_from_here", + "jl_box_int64", + "jl_box_int32", + "ijl_box_int64", + "ijl_box_int32", + "jl_box_uint64", + "jl_box_uint32", + "ijl_box_uint64", + "ijl_box_uint32", + "ijl_box_char", + "jl_box_char", + "ijl_subtype", + "jl_subtype", + "julia.get_pgcstack", + "jl_in_threaded_region", + "jl_object_id_", + "jl_object_id", + "ijl_object_id_", + "ijl_object_id", + "jl_breakpoint", + "llvm.julia.gc_preserve_begin", + "llvm.julia.gc_preserve_end", + "jl_get_ptls_states", + "ijl_get_ptls_states", + "jl_f_fieldtype", + "jl_symbol_n", + "jl_stored_inline", + "ijl_stored_inline", + "jl_f_apply_type", + "jl_f_issubtype", + "jl_isa", + "ijl_isa", + "jl_matching_methods", + "ijl_matching_methods", + "jl_excstack_state", + "ijl_excstack_state", + "jl_current_exception", + "ijl_current_exception", + "memhash_seed", + "jl_f__typevar", + "ijl_f__typevar", + "jl_f_isa", + "ijl_f_isa", + "jl_set_task_threadpoolid", + "ijl_set_task_threadpoolid", + "jl_types_equal", + "ijl_types_equal", + "jl_invoke", + "ijl_invoke", + "jl_apply_generic", + "ijl_apply_generic", + "jl_egal__unboxed", + "julia.pointer_from_objref", + "_platform_memcmp", +
"memcmp", + "julia.except_enter", + "jl_array_grow_end", + "ijl_array_grow_end", + "jl_f_getfield", + "ijl_f_getfield", + "jl_pop_handler", + "ijl_pop_handler", + "jl_pop_handler_noexcept", + "ijl_pop_handler_noexcept", + "jl_string_to_array", + "ijl_string_to_array", + "jl_alloc_string", + "ijl_alloc_string", + "getenv", + "jl_cstr_to_string", + "ijl_cstr_to_string", + "jl_symbol_n", + "ijl_symbol_n", + "uv_os_homedir", + "jl_array_to_string", + "ijl_array_to_string", + "pcre2_jit_compile_8", + "memmove", +)) + +const inactivefns = Set{String}(( + "ClientGetDevice", + "BufferOnCPU", + "pcre2_match_data_create_from_pattern_8", + "ijl_typeassert", + "jl_typeassert", + "ijl_f_isdefined", + "jl_f_isdefined", + "ijl_field_index", + "jl_field_index", + "ijl_specializations_get_linfo", + "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", + "jl_gf_invoke_lookup_worlds", + "ijl_gc_get_total_bytes", + "jl_gc_get_total_bytes", + "ijl_try_substrtod", + "jl_try_substrtod", + "ijl_tagged_gensym", + "jl_tagged_gensym", + "jl_get_world_counter", + "ijl_get_world_counter", + "memhash32_seed", + "memhash_seed", + "ijl_module_parent", + "jl_module_parent", + "julia.safepoint", + "ijl_set_task_tid", + "jl_set_task_tid", + "ijl_get_task_tid", + "jl_get_task_tid", + "julia.get_pgcstack_or_new", + "ijl_global_event_loop", + "jl_global_event_loop", + "ijl_gf_invoke_lookup", + "jl_gf_invoke_lookup", + "ijl_f_typeassert", + "jl_f_typeassert", + "ijl_type_unionall", + "jl_type_unionall", + "jl_gc_queue_root", + "gpu_report_exception", + "gpu_signal_exception", + "julia.ptls_states", + "julia.write_barrier", + "julia.typeof", + "jl_backtrace_from_here", + "ijl_backtrace_from_here", + "jl_box_int64", + "jl_box_int32", + "ijl_box_int64", + "ijl_box_int32", + "jl_box_uint64", + "jl_box_uint32", + "ijl_box_uint64", + "ijl_box_uint32", + "ijl_box_char", + "jl_box_char", + "ijl_subtype", + "jl_subtype", + "julia.get_pgcstack", + "jl_in_threaded_region", + "jl_object_id_", + 
"jl_object_id", + "ijl_object_id_", + "ijl_object_id", + "jl_breakpoint", + "llvm.julia.gc_preserve_begin", + "llvm.julia.gc_preserve_end", + "jl_get_ptls_states", + "ijl_get_ptls_states", + "jl_f_fieldtype", + "jl_symbol_n", + "jl_stored_inline", + "ijl_stored_inline", + "jl_f_apply_type", + "jl_f_issubtype", + "jl_isa", + "ijl_isa", + "jl_matching_methods", + "ijl_matching_methods", + "jl_excstack_state", + "ijl_excstack_state", + "jl_current_exception", + "ijl_current_exception", + "memhash_seed", + "jl_f__typevar", + "ijl_f__typevar", + "jl_f_isa", + "ijl_f_isa", + "jl_set_task_threadpoolid", + "ijl_set_task_threadpoolid", + "jl_types_equal", + "ijl_types_equal", + "jl_string_to_array", + "ijl_string_to_array", + "jl_alloc_string", + "ijl_alloc_string", + "getenv", + "jl_cstr_to_string", + "ijl_cstr_to_string", + "jl_symbol_n", + "ijl_symbol_n", + "uv_os_homedir", + "jl_array_to_string", + "ijl_array_to_string", + "pcre2_jit_compile_8", + # "jl_" +)) + +const activefns = Set{String}(("jl_",)) + +const inactiveglobs = Set{String}(( + "ijl_boxed_uint8_cache", + "jl_boxed_uint8_cache", + "ijl_boxed_int8_cache", + "jl_boxed_int8_cache", + "jl_nothing", +)) + +function annotate!(mod::LLVM.Module) + inactive = LLVM.StringAttribute("enzyme_inactive", "") + active = LLVM.StringAttribute("enzyme_active", "") + no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") + + funcs = Dict{String, Vector{LLVM.Function}}() + for f in functions(mod) + fname = LLVM.name(f) + for fattr in collect(function_attributes(f)) + if isa(fattr, LLVM.StringAttribute) + if kind(fattr) == "enzyme_math" + fname = LLVM.value(fattr) + break + end + end + end + fname = String(fname) + if !haskey(funcs, fname) + funcs[fname] = LLVM.Function[] + end + push!(funcs[String(fname)], f) + API.EnzymeAttributeKnownFunctions(f.ref) + end + + for gname in inactiveglobs + globs = LLVM.globals(mod) + if haskey(globs, gname) + glob = globs[gname] + API.SetMD(glob, "enzyme_inactive", 
LLVM.MDNode(LLVM.Metadata[])) + end + end + + for fname in inactivefns + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), inactive) + push!(function_attributes(fn), no_escaping_alloc) + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if !isa(cf, LLVM.Function) + continue + end + if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" + continue + end + if operands(c)[1] != fn + continue + end + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) + end + end + end + end + + for fname in nofreefns + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.EnumAttribute("nofree", 0)) + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if !isa(cf, LLVM.Function) + continue + end + if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" + continue + end + if operands(c)[1] != fn + continue + end + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + LLVM.EnumAttribute("nofree", 0), + ) + end + end + end + end + + for fname in activefns + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), active) + end + end + end + + for fname in + ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end + push!(function_attributes(fn), 
LLVM.StringAttribute("enzyme_shouldrecompute")) + end + end + end + for fname in ("julia.typeof",) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + push!(parameter_attributes(fn, 1), LLVM.EnumAttribute("nocapture")) + end + end + end + + for fname in + ("jl_excstack_state", "ijl_excstack_state", "ijl_field_index", "jl_field_index") + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end + end + end + end + + for fname in ("jl_types_equal", "ijl_types_equal") + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) + end + end + end + + for fname in ( + "UnsafeBufferPointer", + ) + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_math", "__dynamic_cast")) + end + end + end + end + + for fname in ( + "jl_f_getfield", + "ijl_f_getfield", + "jl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_f__svec_ref", + "ijl_f__svec_ref", + "UnsafeBufferPointer" + ) + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + else + push!(function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + ) + end + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, 
LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if !isa(cf, LLVM.Function) + continue + end + if LLVM.name(cf) != "julia.call" && LLVM.name(cf) != "julia.call2" + continue + end + if operands(c)[1] != fn + continue + end + attr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("readonly") + else + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + end + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + attr, + ) + end + end + end + end + + for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states") + if haskey(funcs, fname) + for fn in funcs[fname] + # TODO per discussion w keno perhaps this should change to readonly / inaccessiblememonly + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) + end + end + end + + for fname in ("julia.gc_loaded",) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + end + end + end + + for fname in ( + "julia.get_pgcstack", + "julia.ptls_states", + "jl_get_ptls_states", + "julia.safepoint", + "ijl_throw", + "julia.pointer_from_objref", + "ijl_array_grow_end", + "jl_array_grow_end", + "ijl_array_del_end", + "jl_array_del_end", + "ijl_array_grow_beg", + "jl_array_grow_beg", + "ijl_array_del_beg", + "jl_array_del_beg", + "ijl_array_grow_at", + "jl_array_grow_at", + "ijl_array_del_at", + "jl_array_del_at", + "ijl_pop_handler", + "jl_pop_handler", + "ijl_pop_handler_noexcept", + "jl_pop_handler_noexcept", + 
"ijl_push_handler", + "jl_push_handler", + "ijl_module_name", + "jl_module_name", + "ijl_restore_excstack", + "jl_restore_excstack", + "julia.except_enter", + "ijl_get_nth_field_checked", + "jl_get_nth_field_checked", + "jl_egal__unboxed", + "ijl_reshape_array", + "jl_reshape_array", + "ijl_eqtable_get", + "jl_eqtable_get", + "jl_gc_run_pending_finalizers", + "ijl_try_substrtod", + "jl_try_substrtod", + ) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(function_attributes(fn), no_escaping_alloc) + end + end + end + + + + for fname in ("julia.pointer_from_objref",) + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end + end + end + end + + for fname in ( + "julia.gc_alloc_obj", + "jl_gc_alloc_typed", + "ijl_gc_alloc_typed", + ) + if haskey(funcs, fname) + for fn in funcs[fname] + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + fn, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("allockind", AllocFnKind(AFKE_Alloc).data)), + ) + push!(function_attributes(fn), no_escaping_alloc) + push!(function_attributes(fn), LLVM.EnumAttribute("allockind", (AllocFnKind(AFKE_Alloc) | AllocFnKind(AFKE_Uninitialized)).data)) + end + end + end + + for fname in ( + "julia.gc_alloc_obj", + "jl_gc_alloc_typed", + "ijl_gc_alloc_typed", + "jl_box_float32", + "jl_box_float64", + "jl_box_int32", + "jl_box_int64", + "ijl_box_float32", + "ijl_box_float64", + "ijl_box_int32", + "ijl_box_int64", + "jl_alloc_genericmemory", + "ijl_alloc_genericmemory", + "jl_alloc_array_1d", + "jl_alloc_array_2d", + "jl_alloc_array_3d", + "ijl_alloc_array_1d", + "ijl_alloc_array_2d", + "ijl_alloc_array_3d", + "jl_array_copy", + "ijl_array_copy", + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice", + "jl_alloc_genericmemory", + "ijl_alloc_genericmemory", + 
"jl_idtable_rehash", + "ijl_idtable_rehash", + "jl_f_tuple", + "ijl_f_tuple", + "jl_new_structv", + "ijl_new_structv", + "ijl_new_array", + "jl_new_array", + ) + if haskey(funcs, fname) + for fn in funcs[fname] + push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) + push!(return_attributes(fn), LLVM.EnumAttribute("nonnull", 0)) + push!(function_attributes(fn), no_escaping_alloc) + push!(function_attributes(fn), LLVM.EnumAttribute("mustprogress")) + push!(function_attributes(fn), LLVM.EnumAttribute("willreturn")) + push!(function_attributes(fn), LLVM.EnumAttribute("nounwind")) + push!(function_attributes(fn), LLVM.EnumAttribute("nofree")) + accattr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("inaccessiblememonly") + else + if fname in ( + "jl_genericmemory_copy_slice", + "ijl_genericmemory_copy_slice",) + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + else + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + end + end + if !( + fname in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) + push!(function_attributes(fn), accattr) + end + for u in LLVM.uses(fn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if cf == fn + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) + if !( + fname in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + accattr, + ) + end + end + if !isa(cf, LLVM.Function) + continue + end + if !(cf == fn || 
+ ((LLVM.name(cf) == "julia.call" || LLVM.name(cf) == "julia.call2") && operands(c)[1] == fn)) + continue + end + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) + if !( + fname in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + accattr, + ) + end + end + end + end + end + + for fname in ("llvm.julia.gc_preserve_begin", "llvm.julia.gc_preserve_end") + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end + end + end + end + + # Key of jl_eqtable_get/put is inactive, definitionally + for fname in ("jl_eqtable_get", "ijl_eqtable_get") + if haskey(funcs, fname) + for fn in funcs[fname] + push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end + end + end + end + + for fname in ("jl_reshape_array", "ijl_reshape_array") + if haskey(funcs, fname) + for fn in funcs[fname] +
push!(parameter_attributes(fn, 3), LLVM.EnumAttribute("readonly")) + push!(parameter_attributes(fn, 3), LLVM.EnumAttribute("nocapture")) + end + end + end + + # Key of jl_eqtable_get/put is inactive, definitionally + for fname in ("jl_eqtable_put", "ijl_eqtable_put") + if haskey(funcs, fname) + for fn in funcs[fname] + push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) + push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_ModRef << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end + end + end + end + + for fname in ("jl_in_threaded_region_", "jl_in_threaded_region") + if haskey(funcs, fname) + for fn in funcs[fname] + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) + end + end + end + end +end + +function mark_gpu_intrinsics!(target, mod::LLVM.Module) + if target isa GPUCompiler.PTXCompilerTarget + + arg1 = ( + "sin", + "cos", + "tan", + "log2", + "exp", + "exp2", + "exp10", + "cosh", + "sinh", + "tanh", + "atan", + "asin", + "acos", + "log", + "log10", + "log1p", + "acosh", + "asinh", + "atanh", + "expm1", + "cbrt", + "rcbrt", + "j0", + "j1", + "y0", + "y1", + "erf", + "erfinv", + "erfc", + "erfcx", + "erfcinv", + "remquo", + 
"tgamma", + "round", + "fdim", + "logb", + "isinf", + "sqrt", + "fabs", + "atan2", + ) + # isinf, finite "modf", "fmod", "remainder", + # "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", + # "hypot", "rhypot", + # "yn", "jn", "norm3d", "ilogb", powi + # "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", + # arg1 = ("atan2", "fmax", "pow") + for n in arg1, + (T, pf, lpf) in + ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) + + fname = "__nv_" * n * pf + if !haskey(functions(mod), fname) + FT = LLVM.FunctionType(T, [T], vararg = false) + wrapper_f = LLVM.Function(mod, fname, FT) + llname = "llvm." * n * "." * lpf + push!( + function_attributes(wrapper_f), + StringAttribute("implements", llname), + ) + push!( + function_attributes(wrapper_f), + StringAttribute("implements2", n * pf) + ) + end + end + end + if target isa GPUCompiler.GCNCompilerTarget + arg1 = ( + "acos", + "acosh", + "asin", + "asinh", + "atan2", + "atan", + "atanh", + "cbrt", + "ceil", + "copysign", + "cos", + "native_cos", + "cosh", + "cospi", + "i0", + "i1", + "erfc", + "erfcinv", + "erfcx", + "erf", + "erfinv", + "exp10", + "native_exp10", + "exp2", + "exp", + "native_exp", + "expm1", + "fabs", + "fdim", + "floor", + "fma", + "fmax", + "fmin", + "fmod", + "frexp", + "hypot", + "ilogb", + "isfinite", + "isinf", + "isnan", + "j0", + "j1", + "ldexp", + "lgamma", + "log10", + "native_log10", + "log1p", + "log2", + "log2", + "logb", + "log", + "native_log", + "modf", + "nearbyint", + "nextafter", + "len3", + "len4", + "ncdf", + "ncdfinv", + "pow", + "pown", + "rcbrt", + "remainder", + "remquo", + "rhypot", + "rint", + "rlen3", + "rlen4", + "round", + "rsqrt", + "scalb", + "scalbn", + "signbit", + "sincos", + "sincospi", + "sin", + "native_sin", + "sinh", + "sinpi", + "sqrt", + "native_sqrt", + "tan", + "tanh", + "tgamma", + "trunc", + "y0", + "y1", + ) + for n in arg1, + (T, pf, lpf) in + ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) + + fname = "__ocml_" 
* n * "_" * lpf + if !haskey(functions(mod), fname) + FT = LLVM.FunctionType(T, [T], vararg = false) + wrapper_f = LLVM.Function(mod, fname, FT) + llname = "llvm." * n * "." * lpf + push!( + function_attributes(wrapper_f), + StringAttribute("implements", llname), + ) + push!( + function_attributes(wrapper_f), + StringAttribute("implements2", n * pf) + ) + end + end + end +end diff --git a/src/sugar.jl b/src/sugar.jl new file mode 100644 index 0000000000..ba53f46a00 --- /dev/null +++ b/src/sugar.jl @@ -0,0 +1,1155 @@ +# Syntactic sugar over autodiff calls (e.g. Enzyme.gradient and Enzyme.jacobian) + + +function zerosetfn(x, i::Int) + res = zero(x) + @inbounds res[i] = 1 + return res +end + +@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:Array} + ir = GPUCompiler.JuliaContext() do ctx + Base.@_inline_meta + + target = Compiler.DefaultCompilerTarget() + params = Compiler.PrimalCompilerParams(API.DEM_ForwardMode) + mi = my_methodinstance(fn, Tuple{T, Int}) + job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params; kernel = false)) + mod, meta = GPUCompiler.codegen( + :llvm, + job; + optimize = false, + cleanup = false, + validate = false, + ) + copysetfn = meta.entry + blk = first(LLVM.blocks(copysetfn)) + iter = LLVM.API.LLVMGetFirstInstruction(blk) + while iter != C_NULL + inst = LLVM.Instruction(iter) + iter = LLVM.API.LLVMGetNextInstruction(iter) + if isa(inst, LLVM.FenceInst) + Compiler.eraseInst(blk, inst) + end + if isa(inst, LLVM.CallInst) + fn = LLVM.called_operand(inst) + if isa(fn, LLVM.Function) + if LLVM.name(fn) == "julia.safepoint" + Compiler.eraseInst(blk, inst) + end + end + end + end + hasNoRet = any( + map( + k -> kind(k) == kind(LLVM.EnumAttribute("noreturn")), + collect(LLVM.function_attributes(copysetfn)), + ), + ) + @assert !hasNoRet + if !hasNoRet + push!(LLVM.function_attributes(copysetfn), LLVM.EnumAttribute("alwaysinline", 0)) + end + ity = convert(LLVM.LLVMType, Int) + jlvaluet = 
convert(LLVM.LLVMType, T; allow_boxed=true) + + FT = LLVM.FunctionType(jlvaluet, LLVM.LLVMType[jlvaluet, ity, ity]) + llvm_f = LLVM.Function(mod, "f", FT) + push!(LLVM.function_attributes(llvm_f), LLVM.EnumAttribute("alwaysinline", 0)) + + # Check if Julia version has https://github.com/JuliaLang/julia/pull/46914 + # and also https://github.com/JuliaLang/julia/pull/47076 + # and also https://github.com/JuliaLang/julia/pull/48620 + needs_dynamic_size_workaround = !(VERSION >= v"1.10.5") + + builder = LLVM.IRBuilder() + entry = LLVM.BasicBlock(llvm_f, "entry") + LLVM.position!(builder, entry) + inp, lstart, len = collect(LLVM.Value, LLVM.parameters(llvm_f)) + + boxed_count = if sizeof(Int) == sizeof(Int64) + Compiler.emit_box_int64!(builder, len) + else + Compiler.emit_box_int32!(builder, len) + end + + tag = Compiler.emit_apply_type!(builder, NTuple, LLVM.Value[boxed_count, unsafe_to_llvm(builder, T)]) + + fullsize = LLVM.nuwmul!(builder, len, LLVM.ConstantInt(sizeof(Int))) + obj = Compiler.emit_allocobj!(builder, tag, fullsize, needs_dynamic_size_workaround) + + T_int8 = LLVM.Int8Type() + LLVM.memset!(builder, obj, LLVM.ConstantInt(T_int8, 0), fullsize, 0) + + alloc = LLVM.pointercast!(builder, obj, LLVM.PointerType(jlvaluet, Tracked)) + alloc = LLVM.pointercast!(builder, alloc, LLVM.PointerType(jlvaluet, 11)) + + loop = LLVM.BasicBlock(llvm_f, "loop") + exit = LLVM.BasicBlock(llvm_f, "exit") + + LLVM.br!(builder, LLVM.icmp!(builder, LLVM.API.LLVMIntEQ, LLVM.ConstantInt(0), len), exit, loop) + + LLVM.position!(builder, loop) + idx = LLVM.phi!(builder, ity) + + push!(LLVM.incoming(idx), (LLVM.ConstantInt(0), entry)) + inc = LLVM.add!(builder, idx, LLVM.ConstantInt(1)) + push!(LLVM.incoming(idx), (inc, loop)) + rval = LLVM.add!(builder, inc, lstart) + res = LLVM.call!(builder, LLVM.function_type(copysetfn), copysetfn, [inp, rval]) + if !hasNoRet + gidx = LLVM.gep!(builder, jlvaluet, alloc, [idx]) + LLVM.store!(builder, res, gidx) + 
Compiler.emit_writebarrier!(builder, Compiler.get_julia_inner_types(builder, obj, res)) + end + + LLVM.br!(builder, LLVM.icmp!(builder, LLVM.API.LLVMIntEQ, inc, len), exit, loop) + + + T_int32 = LLVM.Int32Type() + + Compiler.reinsert_gcmarker!(llvm_f) + + LLVM.position!(builder, exit) + LLVM.ret!(builder, obj) + + string(mod) + end + return quote + Base.@_inline_meta + Base.llvmcall( + ($ir, "f"), + Tuple{Vararg{T}}, + Tuple{T, Int, Int}, + x, + startv, + lengthv + ) + end +end + +@inline function onehot(x::Array) + onehot_internal(zerosetfn, x, 0, length(x)) +end + +@inline function onehot(x::Array, start::Int, endl::Int) + onehot_internal(zerosetfn, x, start-1, endl-start+1) +end + +@inline function onehot(x::AbstractArray) + N = length(x) + ntuple(Val(N)) do i + Base.@_inline_meta + res = similar(x) + for idx = 1:N + @inbounds res[idx] = (i == idx) ? 1.0 : 0.0 + end + return res + end +end +@inline function onehot(x::AbstractArray, start::Int, endl::Int) + ntuple(Val(endl - start + 1)) do i + Base.@_inline_meta + res = similar(x) + for idx = 1:length(x) + @inbounds res[idx] = (i + start - 1 == idx) ? 1.0 : 0.0 + end + return res + end +end + +@inline function onehot(::Type{NTuple{N,T}}) where {T,N} + ntuple(Val(N)) do i + Base.@_inline_meta + ntuple(Val(N)) do idx + Base.@_inline_meta + return (i == idx) ? T(1) : T(0) + end + end +end +@inline onehot(x::Tuple{}) = () +@inline function onehot(x::NTuple{N,T}) where {T,N} + onehot(NTuple{N,T}) +end +@inline function onehot(x::NTuple{N,T}, start, endl) where {T,N} + ntuple(Val(endl - start + 1)) do i + Base.@_inline_meta + ntuple(Val(N)) do idx + Base.@_inline_meta + return (i + start - 1 == idx) ? T(1) : T(0) + end + end +end + +@inline function onehot(x::AbstractFloat) + return (one(x),) +end + +""" + gradient(::ReverseMode, f, args...) + +Compute the gradient of a real-valued function `f` using reverse mode. 
+For each differentiable argument, this function will allocate and return a new derivative object, returning +a tuple of derivatives for each argument. If an argument is not differentiable, the element of the returned +tuple will be nothing. + +In reverse mode (here), the derivatives will be the same type as the original argument. + +This is a structural gradient. For a struct `x` it returns another instance of the same type, +whose fields contain the components of the gradient. +In the result, `grad.a` contains `∂f/∂x.a` for any differentiable `x.a`, +while `grad.c == x.c` for other types. + +Examples: + +```jldoctest gradient +f(x) = x[1]*x[2] + +grad = gradient(Reverse, f, [2.0, 3.0]) + +# output +([3.0, 2.0],) +``` + +```jldoctest gradient +grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) + +# output + +((a = 3.0, b = [2.0], c = "str"),) +``` + +```jldoctest gradient +mul(x, y) = x[1]*y[1] + +grad = gradient(Reverse, mul, [2.0], [3.0]) + +# output +([3.0], [2.0]) +``` + +```jldoctest gradient + +grad = gradient(Reverse, mul, [2.0], Const([3.0])) + +# output +([3.0], nothing) +``` + +If passing a mode that returns the primal (e.g. ReverseWithPrimal), the return type will instead be +a named tuple whose `derivs` field contains the derivatives and whose `val` field contains the result of the original computation. 
+ +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) + +# output +(derivs = ([3.0, 2.0],), val = 6.0) +``` +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0]) + +# output +(derivs = ([3.0], [2.0]), val = 6.0) +``` + +```jldoctest gradient +grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) + +# output +(derivs = ([3.0], nothing), val = 6.0) +``` + +""" +@generated function gradient( + rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + f::F, + x::ty_0, + args::Vararg{Any,N}, +) where {F,ty_0,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten,N} + toemit = Expr[quote + act_0 = + !(x isa Enzyme.Const) && + Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == + Compiler.ActiveState #=justActive=# + end] + rargs = Union{Symbol,Expr}[:x] + acts = Symbol[Symbol("act_0")] + + for i = 1:N + argidx = quote + args[$i] + end + push!(rargs, argidx) + sym = Symbol("act_$i") + push!(acts, sym) + push!( + toemit, + quote + $sym = + !($argidx isa Enzyme.Const) && + Compiler.active_reg_inner( + Core.Typeof($argidx), + (), + nothing, + Val(true), + ) == Compiler.ActiveState #=justActive=# + end, + ) + end + + idx = 0 + shadows = Symbol[] + enz_args = Expr[] + resargs = Expr[] + for (arg, act) in zip(rargs, acts) + shad = Symbol("shad_$idx") + push!(shadows, shad) + push!(toemit, quote + $shad = if $arg isa Enzyme.Const + nothing + elseif $act + Ref(make_zero($arg)) + else + make_zero($arg) + end + end) + push!(enz_args, quote + if $arg isa Enzyme.Const + $arg + elseif $act + MixedDuplicated($arg, $shad) + else + Duplicated($arg, $shad) + end + end) + push!(resargs, quote + if $arg isa Enzyme.Const + nothing + elseif $act + $shad[] + else + $shad + end + end) + idx += 1 + end + push!(toemit, quote + res = autodiff(rm, f, Active, $(enz_args...)) + end) + + if ReturnPrimal + return quote + Base.@_inline_meta + $(toemit...) 
+ (; derivs = ($(resargs...),), val = res[2]) + end + else + return quote + Base.@_inline_meta + $(toemit...) + ($(resargs...),) + end + end +end + +""" + gradient!(::ReverseMode, dx, f, x) + +Compute the gradient of an array-input function `f` using reverse mode, +storing the derivative result in an existing array `dx`. +Both `x` and `dx` must be `Array`s of the same type. + +Example: + +```jldoctest gradip +f(x) = x[1]*x[2] + +dx = [0.0, 0.0] +gradient!(Reverse, dx, f, [2.0, 3.0]) + +# output +([3.0, 2.0],) +``` + +```jldoctest gradip +dx = [0.0, 0.0] +gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) + +# output +(derivs = ([3.0, 2.0],), val = 6.0) +``` +""" +@inline function gradient!( + rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + dx::X, + f::F, + x::X, +) where {X<:Array,F,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} + make_zero!(dx) + res = autodiff(rm, f, Active, Duplicated(x, dx)) + return if ReturnPrimal + (; derivs = (dx,), val = res[2]) + else + (dx,) + end +end + +@inline function chunkedonehot(x, ::Val{chunk}) where {chunk} + sz = length(x) + num = ((sz + chunk - 1) ÷ chunk) + ntuple(Val(num)) do i + Base.@_inline_meta + onehot(x, (i - 1) * chunk + 1, i == num ? sz : (i * chunk)) + end +end + +@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where {chunk} + return ((one(x),),) +end + +@inline tupleconcat(x) = x +@inline tupleconcat(x, y) = (x..., y...) +@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) 
+ +@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N} + args = Union{Symbol,Expr}[:x] + tys = Type[X] + for i in 1:N + push!(args, :(vargs[$i])) + push!(tys, vargs[i]) + end + + exprs = Union{Symbol,Expr}[] + for (arg, ty) in zip(args, tys) + if ty <: Enzyme.Const + push!(exprs, :(nothing)) + elseif ty <: AbstractFloat + push!(exprs, :(nothing)) + elseif ChunkTy == Nothing || ChunkTy == Val{1} + push!(exprs, :(onehot($arg))) + else + push!(exprs, :(chunkedonehot($arg, chunk))) + end + end + return quote + Base.@_inline_meta + ($(exprs...),) + end +end + +struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N} + data::NTuple{Length,T} +end +TupleArray(data::NTuple{Length,T}, Shape) where {Length,T} = + TupleArray{T,Shape,Length,length(Shape)}(data) + +@inline Base.eltype(::TupleArray{T}) where {T} = T +@inline Base.eltype(::Type{<:TupleArray{T}}) where {T} = T +@inline Base.size(::TupleArray{<:Any,Shape}) where {Shape} = Shape +@inline Base.ndims(::TupleArray{<:Any,<:Any,<:Any,N}) where {N} = N + +function Base.convert( + ::Type{Array{T,N}}, + X::TupleArray{T,Shape,Length,N}, +) where {T,Shape,Length,N} + vals = Array{T,N}(undef, Shape...) + for i = 1:Length + @inbounds vals[i] = X.data[i] # copy element i of the tuple storage into the output array `vals` + end + return vals +end + +function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} + start = 0 + for i = 1:N + start *= size(a, N - i + 1) + start += (args[N-i+1] - 1) + end + start += 1 + return a.data[start] +end + +@inline function tupstack(data::Tuple{Vararg{Array{T}}}, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) where {T} + num = prod(outshape) + res = Array{T}(undef, outshape..., inshape...) 
+ for (i, val) in enumerate(data) + Base.unsafe_copyto!(res, num*(i-1)+1, val, 1, Base.reinterpret(UInt, num)) + end + res +end + +@inline function tupstack(x, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}}) + st = Base.stack(x) + if length(outshape) == 1 + st + else + reshape(st, (outshape..., inshape...)) + end +end + +@inline specialize_output(output, input) = output + +""" + gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) + +Compute the gradient of an array-input function `f` using forward mode. The +optional keyword argument `shadows` is a vector of one-hot vectors of type `x` +which are used to forward-propagate into the return. For performance reasons, +this should be computed once, outside the call to `gradient`, rather than +within this call. + +Example: + +```jldoctest gradfwd +f(x) = x[1]*x[2] + +gradient(Forward, f, [2.0, 3.0]) + +# output + +([3.0, 2.0],) +``` + +```jldoctest gradfwd +gradient(ForwardWithPrimal, f, [2.0, 3.0]) + +# output +(derivs = ([3.0, 2.0],), val = 6.0) +``` + +```jldoctest gradfwd +gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) + +# output + +([3.0, 2.0],) +``` + +```jldoctest gradfwd +gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) + +# output +(derivs = ([3.0, 2.0],), val = 6.0) +``` + +For functions which return an AbstractArray or scalar, this function will return an AbstractArray +whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made +about the type of the AbstractArray returned by this function (which may or may not be the same +as the input AbstractArray if provided). + +For functions which return other types, this function will return an AbstractArray +of shape `size(input)` of values of the output type. 
+```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = gradient(Forward, f, [2.0, 3.0, 4.0]) + +# output +([3.0 2.0 0.0; 0.0 1.0 1.0],) +``` + +This function supports multiple arguments and computes the gradient with respect to each + +```jldoctest gradfwd2 +mul(x, y) = x[1]*y[2] + x[2]*y[1] + +gradient(Forward, mul, [2.0, 3.0], [2.7, 3.1]) + +# output + +([3.1, 2.7], [3.0, 2.0]) +``` + +This includes the ability to mark some arguments as `Const` if its derivative is not needed, returning nothing in the corresponding derivative map. + +```jldoctest gradfwd2 +gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) + +# output + +([3.1, 2.7], nothing) +``` +""" +@generated function gradient( + fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + f::F, + x::ty_0, + args::Vararg{Any,N}; + chunk::CS = nothing, + shadows::ST = create_shadows(chunk, x, args...), +) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS,ST, ty_0, N} + + syms = Union{Symbol,Expr}[:x] + shads = Union{Symbol,Expr}[:(shadows[1])] + tys = Type[ty_0] + for i in 1:N + push!(syms, :(args[$i])) + push!(tys, args[i]) + push!(shads, :(shadows[1+$i])) + end + fval = if F <: Annotation + :(f.val) + else + :f + end + + vals = Union{Symbol,Expr}[] + consts = Union{Symbol,Expr}[] + for (arg, ty) in zip(syms, tys) + if ty <: Const + push!(vals, :($arg.val)) + push!(consts, arg) + else + push!(vals, arg) + push!(consts, :(Const($arg))) + end + end + + if CS == Val{0} + return quote + Base.@_inline_meta + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + end + + exprs = Union{Symbol,Expr}[] + primal = nothing + derivatives = Union{Symbol,Expr}[] + + primmode = :(fm) + for (i, (arg, ty)) in enumerate(zip(syms, tys)) + if ty <: Const + push!(derivatives, :(nothing)) + continue + end + + argnum = length(ST.parameters[i].parameters) + + argderivative = if ty <: AbstractFloat + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + 
push!(dargs, :(Duplicated($arg, one($arg)))) + else + push!(dargs, consts[j]) + end + end + + resp = Symbol("resp_$i") + push!(exprs, quote + $resp = autodiff($primmode, f, Duplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + end + + :($resp[1]) + elseif argnum == 0 + vals[i] + elseif CS == Nothing + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(BatchDuplicated($arg, $(shads[i])))) + else + push!(dargs, consts[j]) + end + end + + df = :f + if F <: Enzyme.Duplicated + zeros = Expr[] + for i in 1:argnum + push!(zeros, :(f.dval)) + end + df = :(BatchDuplicated(f.val, ($(zeros...),) )) + end + + resp = Symbol("resp_$i") + push!(exprs, quote + $resp = autodiff($primmode, $df, BatchDuplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + end + + :(values($resp[1])) + elseif CS == Val{1} + subderivatives = Union{Symbol,Expr}[] + for an in 1:argnum + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(Duplicated($arg, $(shads[i])[$an]))) + else + push!(dargs, consts[j]) + end + end + + resp = Symbol("resp_$i"*"_"*string(an)) + push!(exprs, quote + $resp = autodiff($primmode, f, Duplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + end + + push!(subderivatives, :(values($resp[1]))) + end + :(($(subderivatives...),)) + else + subderivatives = Union{Symbol,Expr}[] + for an in 1:argnum + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(BatchDuplicated($arg, $(shads[i])[$an]))) + else + push!(dargs, consts[j]) + end + end + + resp = Symbol("resp_$i"*"_"*string(an)) + push!(exprs, quote + $resp = autodiff($primmode, f, BatchDuplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + 
end + + push!(subderivatives, :(values($resp[1]))) + end + :(tupleconcat($(subderivatives...))) + end + + deriv = if ty <: AbstractFloat + argderivative + else + tmp = Symbol("tmp_$i") + push!(exprs, :($tmp = $argderivative)) + if ty <: AbstractArray + if argnum > 0 + quote + if $tmp[1] isa AbstractArray + inshape = size($(vals[1])) + outshape = size($tmp[1]) + # st : outshape x total inputs + tupstack($tmp, outshape, inshape) + else + specialize_output(TupleArray($tmp, size($arg)), $(vals[1])) + end + end + else + tmp + end + else + tmp + end + end + push!(derivatives, deriv) + end + + # We weirdly asked for no derivatives + if ReturnPrimal && primal == nothing + primal = :($fval($(vals...))) + end + + result = if ReturnPrimal + :((; derivs = ($(derivatives...),), val = $primal)) + else + :(($(derivatives...),)) + end + + return quote + Base.@_inline_meta + $(exprs...) + $result + end +end + +""" + jacobian(::ForwardMode, args...; kwargs...) + +Equivalent to gradient(::ForwardMode, args...; kwargs...) +""" +@inline function jacobian(fm::ForwardMode, args...; kwargs...) + gradient(fm, args...; kwargs...) +end + +""" + jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=nothing) + jacobian(::ReverseMode, f, x) + +Compute the jacobian of a array-output function `f` using (potentially vector) +reverse mode. The `chunk` argument optionally denotes the chunk size to use and +`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`). 
+ +Example: + +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +jacobian(Reverse, f, [2.0, 3.0, 4.0]) + +# output +([3.0 2.0 0.0; 0.0 1.0 1.0],) +``` + +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0]) + +# output +(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) +``` + +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) + +# output +([3.0 2.0 0.0; 0.0 1.0 1.0],) +``` + +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) + +# output +(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) +``` + +This function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. +No guarantees are presently made about the type of the AbstractArray returned by this function +(which may or may not be the same as the input AbstractArray if provided). + +In the future, when this function is extended to handle non-array return types, +this function will return an AbstractArray of shape `size(output)` of values of the input type. 
+``` +""" +@inline function jacobian( + mode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + f::F, + x::X; + n_outs::OutType = nothing, + chunk::CT = nothing, +) where {ReturnPrimal,F,X,RABI<:ABI,ErrIfFuncWritten,RuntimeActivity,OutType,CT,Holomorphic} + + if n_outs == nothing + res = if f isa Const + f.val(x) + else + f(x) + end + jac = if res isa AbstractArray + jacobian( + ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), + f, + x; + n_outs = Val(size(res)), + chunk, + ) + elseif res isa AbstractFloat + gradient( + ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), + f, + x, + ) + else + throw( + AssertionError( + "Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))", + ), + ) + end + + return if ReturnPrimal + (; derivs = jac, val = res) + else + jac + end + else + n_out_val = if length(Compiler.element(n_outs)) == 0 + 0 + else + prod(Compiler.element(n_outs)) + end + + if chunk == Val(0) + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, (), nothing, Val(true)) == Compiler.ActiveState #=justActive=# + tt = Tuple{XT} + FRT = if f isa Const + Core.Typeof(f.val) + else + Core.Typeof(f) + end + + rt = Compiler.primal_return_type(mode, FRT, tt) + + ModifiedBetweenT = (false, false) + FA = Const{FRT} + + if chunk == Val(1) || chunk == nothing + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + #=width=#1, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), + FA, + DuplicatedNoNeed{rt}, + MD ? MixedDuplicated{XT} : Duplicated{XT} + ) + tmp = ntuple(Val(n_out_val)) do i + Base.@_inline_meta + z = make_zero(x) + dx = MD ? Ref(z) : z + res = primal(Const(f), MD ? 
MixedDuplicated(x, dx) : Duplicated(x, dx)) + tape = res[1] + @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) + adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) + return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) + end + rows = map(first, tmp) + outshape = tmp[1][2] + rows, outshape + else + chunksize = Compiler.element(chunk) + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + chunksize, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), + FA, + BatchDuplicatedNoNeed{rt, chunksize}, + MD ? BatchMixedDuplicated{XT, chunksize} : BatchDuplicated{XT, chunksize} + ) + + num = ((n_out_val + chunksize - 1) ÷ chunksize) + + if num * chunksize == n_out_val + last_size = chunksize + primal2, adjoint2 = primal, adjoint + else + last_size = n_out_val - (num - 1) * chunksize + tt′ = Tuple{BatchDuplicated{Core.Typeof(x),last_size}} + primal2, adjoint2 = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + last_size, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), + FA, + BatchDuplicatedNoNeed{rt, last_size}, + MD ? BatchMixedDuplicated{XT, last_size} : BatchDuplicated{XT, last_size} + ) + end + + tmp = ntuple(num) do i + Base.@_inline_meta + dx = ntuple(Val(i == num ? last_size : chunksize)) do idx + Base.@_inline_meta + z = make_zero(x) + MD ? Ref(z) : z + end + res = (i == num ? primal2 : primal)( + Const(f), + MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), + ) + tape = res[1] + j = 0 + for shadow in res[3] + j += 1 + @inbounds shadow[(i-1)*chunksize+j] += + Compiler.default_adjoint(eltype(typeof(shadow))) + end + (i == num ? adjoint2 : adjoint)( + Const(f), + MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), + tape, + ) + return MD ? ( + ntuple(Val(i == num ? 
last_size : chunksize)) do idx + Base.@_inline_meta + dx[idx][] + end + ) : dx, + (i == 1 ? size(res[3][1]) : nothing) + end + rows = tupleconcat(map(first, tmp)...) + outshape = tmp[1][2] + rows, outshape + end + res = if x isa AbstractArray + inshape = size(x) + st2 = tupstack(rows, inshape, outshape) + + st3 = if length(outshape) == 1 && length(inshape) == 1 + transpose(st2) + else + transp = ( + ((length(inshape)+1):(length(inshape)+length(outshape)))..., + (1:length(inshape))..., + ) + PermutedDimsArray(st2, transp) + end + + st3 + else + reshape(collect(rows), outshape) + end + if ReturnPrimal + # TODO optimize away redundant fwd pass + (; derivs = (res,), val = if f isa Enzyme.Const + f.val(x) + else + f(x) + end) + else + (res,) + end + end +end + +""" + hvp(f::F, x::X, v::X) where {F, X} + +Compute the Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v`. + +In other words, compute hessian(f)(x) * v + +See [`hvp!`](@ref) for a version which stores the result in an existing buffer and also [`hvp_and_gradient!`](@ref) for a function to compute both the hvp and the gradient in a single call. + +Example: + +```jldoctest hvp; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +f(x) = sin(x[1] * x[2]) + +hvp(f, [2.0, 3.0], [5.0, 2.7]) + +# output +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 +``` +""" +@inline function hvp(f::F, x::X, v::X) where {F,X} + res = make_zero(x) + hvp!(res, f, x, v) + return res +end + + +""" + hvp!(res::X, f::F, x::X, v::X) where {F, X} + +Compute an in-place Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v`. +The result will be stored into `res`. The function still allocates and zero's a buffer to store the intermediate gradient, which is +not returned to the user. 
+ +In other words, compute res .= hessian(f)(x) * v + +See [`hvp_and_gradient!`](@ref) for a function to compute both the hvp and the gradient in a single call. + +Example: + +```jldoctest hvpip; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +f(x) = sin(x[1] * x[2]) + +res = Vector{Float64}(undef, 2) +hvp!(res, f, [2.0, 3.0], [5.0, 2.7]) + +res +# output +2-element Vector{Float64}: + 19.6926882637302 + 16.201003759768003 +``` +""" +@inline function hvp!(res::X, f::F, x::X, v::X) where {F,X} + grad = make_zero(x) + Enzyme.autodiff( + Forward, + gradient!, + Const(Reverse), + DuplicatedNoNeed(grad, res), + Const(f), + Duplicated(x, v), + ) + return nothing +end + + + +""" + hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} + +Compute an in-place Hessian-vector product of an array-input scalar-output function `f`, as evaluated at `x` times the vector `v` as well as +the gradient, storing the gradient into `grad`. Both the hessian vector product and the gradient can be computed together more efficiently +than computing them separately. + +The result will be stored into `res`. The gradient will be stored into `grad`. 
+ +In other words, compute res .= hessian(f)(x) * v and grad .= gradient(Reverse, f)(x) + +Example: + +```jldoctest hvp_and_gradient; filter = r"([0-9]+\\.[0-9]{8})[0-9]+" => s"\\1***" +f(x) = sin(x[1] * x[2]) + +res = Vector{Float64}(undef, 2) +grad = Vector{Float64}(undef, 2) +hvp_and_gradient!(res, grad, f, [2.0, 3.0], [5.0, 2.7]) + +res +grad +# output +2-element Vector{Float64}: + 2.880510859951098 + 1.920340573300732 +``` +""" +@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F,X} + Enzyme.autodiff( + Forward, + gradient!, + Const(Reverse), + Duplicated(grad, res), + Const(f), + Duplicated(x, v), + ) + return nothing +end + diff --git a/src/make_zero.jl b/src/typeutils/make_zero.jl similarity index 100% rename from src/make_zero.jl rename to src/typeutils/make_zero.jl diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl new file mode 100644 index 0000000000..039f7d3d0c --- /dev/null +++ b/src/typeutils/recursive_add.jl @@ -0,0 +1,86 @@ +# Recursively return x + f(y), where y is active, otherwise x + +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T,F,F2} + if forcelhs(T) + return x + end + splatnew(T, ntuple(Val(fieldcount(T))) do i + Base.@_inline_meta + prev = getfield(x, i) + next = getfield(y, i) + recursive_add(prev, next, f, forcelhs) + end) +end + +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T<:AbstractFloat,F,F2} + if forcelhs(T) + return x + end + return x + f(y) +end + +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T<:Complex,F,F2} + if forcelhs(T) + return x + end + return x + f(y) +end + +@inline mutable_register(::Type{T}) where {T<:Integer} = true +@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false +@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false 
+@inline mutable_register(::Type{T}) where {T<:Tuple} = false +@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false +@inline mutable_register(::Type{Core.Box}) = true +@inline mutable_register(::Type{T}) where {T<:Array} = true +@inline mutable_register(::Type{T}) where {T} = ismutabletype(T) + +# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) +@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F} + if !mutable_register(T) + for I in eachindex(x) + prev = x[I] + @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) + end + end +end + + +# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) +@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F} + recursive_accumulate(x.contents, y.contents, f) # forward to the (x, y, f) method on the boxed contents +end + +@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F} + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + + for i = 1:nf + if isdefined(x, i) + xi = getfield(x, i) + ST = Core.Typeof(xi) + if !mutable_register(ST) + @assert ismutable(x) + yi = getfield(y, i) + nexti = recursive_add(xi, yi, f, mutable_register) + setfield!(x, i, nexti) + end + end + end +end From 22818bf12ff50285e306570d1725b33103d582e3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 11:28:57 -0500 Subject: [PATCH 462/495] Fix fwd to not have ref on active rtfix (#2142) * Fix fwd to not have ref on active rtfix * Update runtests.jl --- src/errors.jl | 21 +++++++++++++++++---- test/runtests.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index c6dd78b781..187e42b30c 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -286,6 +286,7 @@ function julia_error( end illegalVal = nothing + mode = get_mode(gutils) function make_replacement(@nospecialize(cur::LLVM.Value), 
prevbb::LLVM.IRBuilder)::LLVM.Value ncur = new_from_original(gutils, cur) @@ -308,15 +309,27 @@ function julia_error( isa(cur, LLVM.ConstantExpr) && cur == data2 if width == 1 - res = emit_allocobj!(prevbb, Base.RefValue{TT}) - push!(created, res) - return res + if mode == API.DEM_ForwardMode + instance = make_zero(obj) + return unsafe_to_llvm(prevbb, instance) + else + res = emit_allocobj!(prevbb, Base.RefValue{TT}) + push!(created, res) + return res + end else shadowres = UndefValue( LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), ) for idx = 1:width - res = emit_allocobj!(prevbb, Base.RefValue{TT}) + res = if mode == API.DEM_ForwardMode + instance = make_zero(obj) + unsafe_to_llvm(prevbb, instance) + else + sres = emit_allocobj!(prevbb, Base.RefValue{TT}) + push!(created, sres) + sres + end shadowres = insert_value!(prevbb, shadowres, res, idx - 1) push!(created, shadowres) end diff --git a/test/runtests.jl b/test/runtests.jl index e331645378..549978b894 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1272,6 +1272,34 @@ end @test dweights[2] ≈ 20. end + +abstract type AbsFwdType end + +# Two copies of the same type. 
+struct FwdNormal1{T<:Real} <: AbsFwdType +σ::T +end + +struct FwdNormal2{T<:Real} <: AbsFwdType +σ::T +end + +fwdlogpdf(d) = d.σ + +function absactfunc(x) + dists = AbsFwdType[FwdNormal1{Float64}(1.0), FwdNormal2{Float64}(x)] + res = Vector{Float64}(undef, 2) + for i in 1:length(dists) + @inbounds res[i] = fwdlogpdf(dists[i]) + end + return @inbounds res[1] + @inbounds res[2] +end + +@testset "Forward Mode active runtime activity" begin + res = Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(absactfunc), Duplicated(2.7, 3.1)) + @test res[1] ≈ 3.1 +end + # dot product (https://github.com/EnzymeAD/Enzyme.jl/issues/495) @testset "Dot product" for T in (Float32, Float64) xx = rand(T, 10) From 2a24bb50ceeaf9a8751a6ff0fa71a3fb6f29a553 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 19:54:23 -0500 Subject: [PATCH 463/495] Speed up activity results no that in noworldage regieme (#2146) --- src/analyses/activity.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index f3dcd3a877..10aebb46cd 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -439,7 +439,7 @@ Try to guess the most appropriate [`Annotation`](@ref) for arguments of type `T` guess_activity(T, convert(API.CDerivativeMode, mode)) @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} - ActReg = active_reg_inner(T, (), nothing) + ActReg = active_reg_nothrow(T, Val(nothing)) if ActReg == AnyState return Const{T} end From c9c79acda8a5f4941e5912d7cd48643cda39d73a Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 19:54:45 -0500 Subject: [PATCH 464/495] Fix batched forward constant get (#2143) * Fix batched forward constant get * fix * Update typeunstable.jl * Update typeunstable.jl * Update typeunstable.jl * fixup --- src/errors.jl | 15 ++++++++++++++- src/rules/typeunstablerules.jl | 11 +++++++++-- test/typeunstable.jl | 22 +++++++++++++++++++++- 3 files 
changed, 44 insertions(+), 4 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index 187e42b30c..a1a81580e6 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -94,6 +94,17 @@ function julia_error( B::LLVM.API.LLVMBuilderRef, )::LLVM.API.LLVMValueRef msg = Base.unsafe_string(cstr) + julia_error(msg, val, errtype, data, data2, B) +end + +function julia_error( + msg::String, + val::LLVM.API.LLVMValueRef, + errtype::API.ErrorType, + data::Ptr{Cvoid}, + data2::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, +)::LLVM.API.LLVMValueRef bt = nothing ir = nothing if val != C_NULL @@ -331,7 +342,9 @@ function julia_error( sres end shadowres = insert_value!(prevbb, shadowres, res, idx - 1) - push!(created, shadowres) + if shadowres isa LLVM.Instruction + push!(created, shadowres) + end end return shadowres end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index fd157890e2..d130bc0abe 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1067,6 +1067,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) shadowres = UndefValue( LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), ) + position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal))) for idx = 1:width shadowres = insert_value!(B, shadowres, normal, idx - 1) end @@ -1534,8 +1535,13 @@ end end origops = collect(operands(orig)) width = get_width(gutils) - if !is_constant_value(gutils, origops[1]) - shadowin = invert_pointer(gutils, origops[1], B) + if !is_constant_value(gutils, origops[1]) || !get_runtime_activity(gutils) + shadowin = if !is_constant_value(gutils, origops[1]) + invert_pointer(gutils, origops[1], B) + else + estr = "Mismatched activity for: " * string(orig) * " const input " *string(origops[1]) * ", differentiable return" + LLVM.Value(julia_error(estr, orig.ref, API.ET_MixedActivityError, gutils.ref, origops[1].ref, B.ref)) + end if width == 1 args = LLVM.Value[ shadowin @@ -1565,6 
+1571,7 @@ end shadowres = UndefValue( LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), ) + position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal))) for idx = 1:width shadowres = insert_value!(B, shadowres, normal, idx - 1) end diff --git a/test/typeunstable.jl b/test/typeunstable.jl index b3600413a1..2992a7fa3f 100644 --- a/test/typeunstable.jl +++ b/test/typeunstable.jl @@ -101,4 +101,24 @@ end res = Enzyme.autodiff(Forward, toactivepair, BatchDuplicated(2.7f0, (2.0f0, 5.0f0)), BatchDuplicated(3.1, (3.0, 7.0))) @test res[1][1] ≈ 2.7f0 * 3.0 + 2.0f0 * 3.1 @test res[1][2] ≈ 2.7f0 * 7.0 + 5.0f0 * 3.1 -end \ No newline at end of file +end + +struct InsFwdNormal1{T<:Real} + σ::T +end + +struct InsFwdNormal2{T<:Real} + σ::T +end + +insfwdlogpdf(d, x) = d.σ + +function insfwdfunc(x) + dists = [InsFwdNormal1{Float64}(1.0), InsFwdNormal2{Float64}(1.0)] + return sum(Base.Fix2(insfwdlogpdf, x), dists) +end + +@testset "Forward Batch Constant insertion" begin + res = Enzyme.gradient(Enzyme.Forward, insfwdfunc, [0.5, 0.7])[1] + @test res ≈ [0.0, 0.0] +end From e69f3c2d208b5b2f7d152dda3d0f939075a400f7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Dec 2024 01:48:27 -0500 Subject: [PATCH 465/495] Spew more debug info on extract assertion (#2149) --- src/absint.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 5a72fa0873..1a684a1422 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -636,8 +636,11 @@ function abs_typeof( return (false, nothing, nothing) end if byref == GPUCompiler.BITS_VALUE + ltyp = typ for ind in offset - @assert Base.isconcretetype(typ) + if !Base.isconcretetype(typ) + throw(AssertionError("Illegal absint of $(string(arg)) ltyp=$ltyp, typ=$typ, offset=$offset, ind=$ind")) + end cnt = 0 for i = 1:fieldcount(typ) styp = typed_fieldtype(typ, i) From 2839d3fdc85ede0dccdd7a4b1692bbef6082a48f Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 1 Dec 2024 
15:03:50 +0100 Subject: [PATCH 466/495] Don't call specialize_method again (#2148) * Don't call specialize_method again * fixup * fix * fix * fix * fix * fixups * more cleaning * fix * cleanup * fix * fix * fix * fix * fixup * fix * fixup * fewer calls in custom rules * more cleanup * fix * fix * fix * fix * fix * fix * ar --------- Co-authored-by: William S. Moses --- src/absint.jl | 9 +- src/analyses/activity.jl | 26 +- src/api.jl | 15 + src/compiler.jl | 1122 +-------------- src/compiler/optimize.jl | 2104 ---------------------------- src/compiler/utils.jl | 81 +- src/compiler/validation.jl | 2 +- src/errors.jl | 79 ++ src/{compiler => llvm}/passes.jl | 1 + src/llvm/transforms.jl | 2181 ++++++++++++++++++++++++++++++ src/rules/activityrules.jl | 9 +- src/rules/customrules.jl | 220 +-- src/rules/llvmrules.jl | 42 +- src/rules/parallelrules.jl | 5 + src/sugar.jl | 18 +- src/typeutils/conversion.jl | 128 ++ src/typeutils/inference.jl | 182 +++ src/typeutils/jltypes.jl | 297 ++++ src/typeutils/lltypes.jl | 200 +++ src/utils.jl | 10 +- 20 files changed, 3290 insertions(+), 3441 deletions(-) rename src/{compiler => llvm}/passes.jl (99%) create mode 100644 src/llvm/transforms.jl create mode 100644 src/typeutils/conversion.jl create mode 100644 src/typeutils/inference.jl create mode 100644 src/typeutils/jltypes.jl create mode 100644 src/typeutils/lltypes.jl diff --git a/src/absint.jl b/src/absint.jl index 1a684a1422..169db8965f 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -740,6 +740,13 @@ function abs_typeof( return (false, nothing, nothing) end +@inline function is_zero(@nospecialize(x::LLVM.Value))::Bool + if x isa LLVM.ConstantInt + return convert(UInt, x) == 0 + end + return false +end + function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if isa(arg, ConstantExpr) ce = arg @@ -747,7 +754,7 @@ function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == 
LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr ce = operands(ce)[1] elseif opcode(ce) == LLVM.API.LLVMGetElementPtr - if all(x -> x isa LLVM.ConstantInt && convert(UInt, x) == 0, operands(ce)[2:end]) + if all(is_zero, operands(ce)[2:end]) ce = operands(ce)[1] else break diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 10aebb46cd..3c29838e70 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -62,7 +62,7 @@ end @inline forcefold(::Val{RT}) where {RT} = RT -@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N}) where {ty,sty,N} +@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N})::ActivityState where {ty,sty,N} if sty == AnyState || sty == ty return forcefold(Val(ty), C...) end @@ -107,11 +107,7 @@ else @inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true end -@inline function datatype_fieldcount(t::Type{T}) where {T} - return Base.datatype_fieldcount(t) -end - -@inline function staticInTup(::Val{T}, tup::NTuple{N,Val}) where {T,N} +Base.@assume_effects :removable :foldable :nothrow @inline function staticInTup(::Val{T}, tup::NTuple{N,Val})::Bool where {T,N} any(ntuple(Val(N)) do i Base.@_inline_meta Val(T) == tup[i] @@ -125,7 +121,7 @@ end ::Val{justActive}, ::Val{UnionSret}, ::Val{AbstractIsMixed}, -) where {ST,Seen,justActive,UnionSret,AbstractIsMixed} +)::ActivityState where {ST,Seen,justActive,UnionSret,AbstractIsMixed} if ST isa Union return forcefold( Val( @@ -285,7 +281,7 @@ end return DupState end end - if datatype_fieldcount(aT) === nothing + if Base.datatype_fieldcount(aT) === nothing if AbstractIsMixed return MixedState else @@ -383,11 +379,11 @@ end return ty end -@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T,world} +Base.@assume_effects :removable :foldable @inline @generated function active_reg_nothrow(::Type{T}, ::Val{world})::ActivityState where {T,world} return active_reg_inner(T, (), world) end -Base.@pure @inline function 
active_reg( +Base.@assume_effects :removable :foldable @inline function active_reg( ::Type{T}, world::Union{Nothing,UInt} = nothing, )::Bool where {T} @@ -411,13 +407,13 @@ Base.@pure @inline function active_reg( end end -@inline function guaranteed_const(::Type{T}) where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const(::Type{T})::Bool where {T} rt = active_reg_nothrow(T, Val(nothing)) res = rt == AnyState return res end -@inline function guaranteed_const_nongen(::Type{T}, world) where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world)::Bool where {T} rt = active_reg_inner(T, (), world) res = rt == AnyState return res @@ -425,7 +421,7 @@ end # check if a value is guaranteed to be not contain active[register] data # (aka not either mixed or active) -@inline function guaranteed_nonactive(::Type{T}) where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive(::Type{T})::Bool where {T} rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end @@ -435,10 +431,10 @@ end Try to guess the most appropriate [`Annotation`](@ref) for arguments of type `T` passed to [`autodiff`](@ref) with a given `mode`. 
""" -@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = +Base.@assume_effects :removable :foldable :nothrow @inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) -@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} ActReg = active_reg_nothrow(T, Val(nothing)) if ActReg == AnyState return Const{T} diff --git a/src/api.jl b/src/api.jl index 3cdba76ae8..0d5de6a15c 100644 --- a/src/api.jl +++ b/src/api.jl @@ -2,6 +2,7 @@ module API import LLVM.API: LLVMValueRef, LLVMModuleRef, LLVMTypeRef, LLVMContextRef using Enzyme_jll +using EnzymeCore using Libdl using LLVM using CEnum @@ -208,6 +209,20 @@ end # but don't need the forward ) +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.Const} = API.DFT_CONSTANT +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.Active} = + API.DFT_OUT_DIFF +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.Duplicated} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.BatchDuplicated} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.BatchDuplicatedFunc} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.DuplicatedNoNeed} = + API.DFT_DUP_NONEED +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:EnzymeCore.BatchDuplicatedNoNeed} = + API.DFT_DUP_NONEED + @cenum( CDerivativeMode, DEM_ForwardMode = 0, diff --git a/src/compiler.jl b/src/compiler.jl index 16d0119e97..0db62b6d48 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -161,20 +161,6 @@ end include("llvm/attributes.jl") -# From 
https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 -@inline function isghostty(@nospecialize(ty)) - if ty === Union{} - return true - end - if Base.isconcretetype(ty) && !ismutabletype(ty) - if sizeof(ty) == 0 - return true - end - # TODO consider struct_to_llvm ? - end - return false -end - include("analyses/activity.jl") # User facing interface @@ -222,87 +208,13 @@ end using .JIT include("jlrt.jl") +include("errors.jl") -AnyArray(Length::Int) = NamedTuple{ntuple(Symbol, Val(Length)),NTuple{Length,Any}} - -struct EnzymeRuntimeException <: Base.Exception - msg::Cstring -end - -function Base.showerror(io::IO, ece::EnzymeRuntimeException) - print(io, "Enzyme execution failed.\n") - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end - -struct EnzymeMutabilityException <: Base.Exception - msg::Cstring -end - -function Base.showerror(io::IO, ece::EnzymeMutabilityException) - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end - -struct EnzymeRuntimeActivityError <: Base.Exception - msg::Cstring -end - -function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) - println(io, "Constant memory is stored (or returned) to a differentiable variable.") - println( - io, - "As a result, Enzyme cannot provably ensure correctness and throws this error.", - ) - println( - io, - "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).", - ) - println( - io, - "If Enzyme should be able to prove this use non-differentable, open an issue!", - ) - println(io, "To work around this issue, either:") - println( - io, - " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or", - ) - println( - io, - " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). 
This will maintain correctness, but may slightly reduce performance.", - ) - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end - -struct EnzymeNoTypeError <: Base.Exception - msg::Cstring -end - -function Base.showerror(io::IO, ece::EnzymeNoTypeError) - print(io, "Enzyme cannot deduce type\n") - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end - -struct EnzymeNoShadowError <: Base.Exception - msg::Cstring -end - -function Base.showerror(io::IO, ece::EnzymeNoShadowError) - print(io, "Enzyme could not find shadow for value\n") - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end - -struct EnzymeNoDerivativeError <: Base.Exception - msg::Cstring -end +include("typeutils/conversion.jl") +include("typeutils/jltypes.jl") +include("typeutils/lltypes.jl") -function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end +AnyArray(Length::Int) = NamedTuple{ntuple(Symbol, Val(Length)),NTuple{Length,Any}} const JuliaEnzymeNameMap = Dict{String,Any}( "enz_val_true" => Val(true), @@ -390,94 +302,8 @@ const JuliaGlobalNameMap = Dict{String,Any}( ) include("absint.jl") - -# Force sret -struct Return2 - ret1::Any - ret2::Any -end - -function force_recompute!(mod::LLVM.Module) - for f in functions(mod), bb in blocks(f) - iter = LLVM.API.LLVMGetFirstInstruction(bb) - while iter != C_NULL - inst = LLVM.Instruction(iter) - iter = LLVM.API.LLVMGetNextInstruction(iter) - if isa(inst, LLVM.LoadInst) - has_loaded = false - for u in LLVM.uses(inst) - v = LLVM.user(u) - if isa(v, LLVM.CallInst) - cf = LLVM.called_operand(v) - if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" && operands(v)[2] == inst - has_loaded = true - break - end - end - if isa(v, LLVM.BitCastInst) - for u2 in LLVM.uses(v) - v2 = LLVM.user(u2) - if isa(v2, LLVM.CallInst) - cf = LLVM.called_operand(v2) - if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" && operands(v2)[2] == v - has_loaded = true - 
break - end - end - end - end - end - if has_loaded - metadata(inst)["enzyme_nocache"] = MDNode(LLVM.Metadata[]) - end - end - if isa(inst, LLVM.CallInst) - cf = LLVM.called_operand(inst) - if isa(cf, LLVM.Function) - if LLVM.name(cf) == "llvm.julia.gc_preserve_begin" - has_use = false - for u2 in LLVM.uses(inst) - has_use = true - break - end - if !has_use - eraseInst(bb, inst) - end - end - end - end - end - end -end - -function permit_inlining!(f::LLVM.Function) - for bb in blocks(f), inst in instructions(bb) - # remove illegal invariant.load and jtbaa_const invariants - if isa(inst, LLVM.LoadInst) - md = metadata(inst) - if haskey(md, LLVM.MD_tbaa) - modified = LLVM.Metadata( - ccall( - (:EnzymeMakeNonConstTBAA, API.libEnzyme), - LLVM.API.LLVMMetadataRef, - (LLVM.API.LLVMMetadataRef,), - md[LLVM.MD_tbaa], - ), - ) - setindex!(md, modified, LLVM.MD_tbaa) - end - if haskey(md, LLVM.MD_invariant_load) - delete!(md, LLVM.MD_invariant_load) - end - end - end -end - -struct Tape{TapeTy,ShadowTy,ResT} - internal_tape::TapeTy - shadow_return::ShadowTy -end - +include("llvm/transforms.jl") +include("llvm/passes.jl") include("typeutils/make_zero.jl") function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) @@ -498,7 +324,7 @@ function prepare_llvm(mod::LLVM.Module, job, meta) end llvmfn = functions(mod)[k_name] - RT = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + RT = return_type(interp, mi) _, _, returnRoots = get_return_info(RT) returnRoots = returnRoots !== nothing @@ -542,17 +368,9 @@ function nested_codegen!( params = PrimalCompilerParams(mode) job = CompilerJob(funcspec, CompilerConfig(target, params; kernel = false), world) - # TODO - parent_job = nothing - - otherMod, meta = GPUCompiler.codegen( - :llvm, - job; - optimize = false, - cleanup = false, - validate = false, - parent_job = parent_job, - ) + GPUCompiler.prepare_job!(job) + otherMod, meta = GPUCompiler.emit_llvm(job; 
libraries=true, toplevel=true, optimize=false, cleanup=false, only_entry=false, validate=false) + prepare_llvm(otherMod, job, meta) entry = name(meta.entry) @@ -601,8 +419,6 @@ function removed_ret_parms(F::LLVM.Function) return retRemove, parmsRemoved end -include("errors.jl") - const CheckNan = Ref(false) function julia_sanitize( orig::LLVM.API.LLVMValueRef, @@ -655,187 +471,11 @@ function julia_sanitize( return val.ref end -function any_jltypes(Type::LLVM.PointerType) - if 10 <= LLVM.addrspace(Type) <= 12 - return true - else - # do we care about {} addrspace(11)** - return false - end -end - -any_jltypes(Type::LLVM.StructType) = any(any_jltypes, LLVM.elements(Type)) -any_jltypes(Type::Union{LLVM.VectorType,LLVM.ArrayType}) = any_jltypes(eltype(Type)) -any_jltypes(::LLVM.IntegerType) = false -any_jltypes(::LLVM.FloatingPointType) = false -any_jltypes(::LLVM.VoidType) = false - -@inline any_jltypes(::Type{Nothing}) = false -@inline any_jltypes(::Type{T}) where {T<:AbstractFloat} = false -@inline any_jltypes(::Type{T}) where {T<:Integer} = false -@inline any_jltypes(::Type{Complex{T}}) where {T} = any_jltypes(T) -@inline any_jltypes(::Type{Tuple{}}) = false -@inline any_jltypes(::Type{NTuple{Size,T}}) where {Size,T} = any_jltypes(T) -@inline any_jltypes(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = 10 <= Addr <= 12 -@inline any_jltypes(::Type{Any}) = true -@inline any_jltypes(::Type{NamedTuple{A,B}}) where {A,B} = - any(any_jltypes(b) for b in B.parameters) -@inline any_jltypes(::Type{T}) where {T<:Tuple} = any(any_jltypes(b) for b in T.parameters) - -nfields(Type::LLVM.StructType) = length(LLVM.elements(Type)) -nfields(Type::LLVM.VectorType) = size(Type) -nfields(Type::LLVM.ArrayType) = length(Type) -nfields(Type::LLVM.PointerType) = 1 - mutable struct EnzymeTapeToLoad{T} data::T end Base.eltype(::EnzymeTapeToLoad{T}) where {T} = T -const TapeTypes = Dict{String,DataType}() - -base_type(T::UnionAll) = base_type(T.body) -base_type(T::DataType) = T - -const 
WideIntWidths = [256, 512, 1024, 2048] - -let - for n ∈ WideIntWidths - let T = Symbol(:UInt, n) - eval(quote - primitive type $T <: Unsigned $n end - end) - end - end -end -# return result and if contains any -function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} - tkind = LLVM.API.LLVMGetTypeKind(Type) - if tkind == LLVM.API.LLVMStructTypeKind - tys = DataType[] - nelems = LLVM.API.LLVMCountStructElementTypes(Type) - containsAny = false - syms = Symbol[] - for i = 1:nelems - e = LLVM.API.LLVMStructGetTypeAtIndex(Type, i - 1) - T, sub = to_tape_type(e) - containsAny |= sub - push!(tys, T) - push!(syms, Symbol(i)) - end - Tup = Tuple{tys...} - if containsAny - res = (syms...,) - return NamedTuple{res,Tup}, false - else - return Tup, false - end - end - if tkind == LLVM.API.LLVMPointerTypeKind - addrspace = LLVM.API.LLVMGetPointerAddressSpace(Type) - if 10 <= addrspace <= 12 - return Any, true - else - e = LLVM.API.LLVMGetElementType(Type) - tkind2 = LLVM.API.LLVMGetTypeKind(e) - if tkind2 == LLVM.API.LLVMFunctionTypeKind - return Core.LLVMPtr{Cvoid,Int(addrspace)}, false - else - return Core.LLVMPtr{to_tape_type(e)[1],Int(addrspace)}, false - end - end - end - if tkind == LLVM.API.LLVMArrayTypeKind - e = LLVM.API.LLVMGetElementType(Type) - T, sub = to_tape_type(e) - len = Int(LLVM.API.LLVMGetArrayLength(Type)) - Tup = NTuple{len,T} - if sub - return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false - else - return Tup, false - end - end - if tkind == LLVM.API.LLVMVectorTypeKind - e = LLVM.API.LLVMGetElementType(Type) - T, sub = to_tape_type(e) - len = Int(LLVM.API.LLVMGetVectorSize(Type)) - Tup = NTuple{len,T} - if sub - return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false - else - return Tup, false - end - end - if tkind == LLVM.API.LLVMIntegerTypeKind - N = LLVM.API.LLVMGetIntTypeWidth(Type) - if N == 1 - return Bool, false - elseif N == 8 - return UInt8, false - elseif N == 16 - return UInt16, false - elseif N == 32 - return UInt32, 
false - elseif N == 64 - return UInt64, false - elseif N == 128 - return UInt128, false - elseif N == 256 - return UInt256, false - elseif N == 512 - return UInt512, false - elseif N == 1024 - return UInt1024, false - elseif N == 2048 - return UInt2048, false - else - error("Can't construct tape type for integer of width $N") - end - end - if tkind == LLVM.API.LLVMHalfTypeKind - return Float16, false - end - if tkind == LLVM.API.LLVMFloatTypeKind - return Float32, false - end - if tkind == LLVM.API.LLVMDoubleTypeKind - return Float64, false - end - if tkind == LLVM.API.LLVMFP128TypeKind - return Float128, false - end - error("Can't construct tape type for $Type $(string(Type)) $tkind") -end - -function tape_type(@nospecialize(LLVMType::LLVM.LLVMType)) - TT, isAny = to_tape_type(LLVMType.ref) - if isAny - return AnonymousStruct(Tuple{Any}) - end - return TT -end - -from_tape_type(::Type{T}) where {T<:AbstractFloat} = convert(LLVMType, T) -from_tape_type(::Type{T}) where {T<:Integer} = convert(LLVMType, T) -from_tape_type(::Type{NTuple{Size,T}}) where {Size,T} = - LLVM.ArrayType(from_tape_type(T), Size) -from_tape_type(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = - LLVM.PointerType(from_tape_type(UInt8), Addr) -# from_tape_type(::Type{Core.LLVMPtr{T, Addr}}, ctx) where {T, Addr} = LLVM.PointerType(from_tape_type(T, ctx), Addr) -from_tape_type(::Type{Any}) = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), Tracked) -function from_tape_type(::Type{NamedTuple{A,B}}) where {A,B} - from_tape_type(B) -end -function from_tape_type(::Type{B}) where {B<:Tuple} - ar = LLVM.LLVMType[from_tape_type(b) for b in B.parameters] - if length(B.parameters) >= 1 && all(ar[1] == b for b in ar) - return LLVM.ArrayType(ar[1], length(B.parameters)) - else - return LLVM.StructType(LLVM.LLVMType[from_tape_type(b) for b in B.parameters]) - end -end - # See get_current_task_from_pgcstack (used from 1.7+) current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ 
sizeof(Ptr{Cvoid})) @@ -844,126 +484,6 @@ current_task_offset() = current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) -function store_nonjl_types!(B::LLVM.IRBuilder, @nospecialize(startval::LLVM.Value), @nospecialize(p::LLVM.Value)) - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - vals = LLVM.Value[] - if p != nothing - push!(vals, p) - end - todo = Tuple{Tuple,LLVM.Value}[((), startval)] - while length(todo) != 0 - path, cur = popfirst!(todo) - ty = value_type(cur) - if isa(ty, LLVM.PointerType) - if any_jltypes(ty) - continue - end - end - if isa(ty, LLVM.ArrayType) - if any_jltypes(ty) - for i = 1:length(ty) - ev = extract_value!(B, cur, i - 1) - push!(todo, ((path..., i - 1), ev)) - end - continue - end - end - if isa(ty, LLVM.StructType) - if any_jltypes(ty) - for (i, t) in enumerate(LLVM.elements(ty)) - ev = extract_value!(B, cur, i - 1) - push!(todo, ((path..., i - 1), ev)) - end - continue - end - end - parray = LLVM.Value[LLVM.ConstantInt(LLVM.IntType(64), 0)] - for v in path - push!(parray, LLVM.ConstantInt(LLVM.IntType(32), v)) - end - gptr = gep!(B, value_type(startval), p, parray) - st = store!(B, cur, gptr) - end - return -end - -function get_julia_inner_types(B::LLVM.IRBuilder, @nospecialize(p::Union{Nothing, LLVM.Value}), @nospecialize(startvals::Vararg{LLVM.Value}); added = LLVM.API.LLVMValueRef[]) - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - vals = LLVM.Value[] - if p != nothing - push!(vals, p) - end - todo = LLVM.Value[startvals...] 
- while length(todo) != 0 - cur = popfirst!(todo) - ty = value_type(cur) - if isa(ty, LLVM.PointerType) - if any_jltypes(ty) - if addrspace(ty) != Tracked - cur = addrspacecast!( - B, - cur, - LLVM.PointerType(eltype(ty), Tracked), - LLVM.name(cur) * ".innertracked", - ) - if isa(cur, LLVM.Instruction) - push!(added, cur.ref) - end - end - if value_type(cur) != T_prjlvalue - cur = bitcast!(B, cur, T_prjlvalue) - if isa(cur, LLVM.Instruction) - push!(added, cur.ref) - end - end - push!(vals, cur) - end - continue - end - if isa(ty, LLVM.ArrayType) - if any_jltypes(ty) - for i = 1:length(ty) - ev = extract_value!(B, cur, i - 1) - if isa(ev, LLVM.Instruction) - push!(added, ev.ref) - end - push!(todo, ev) - end - end - continue - end - if isa(ty, LLVM.StructType) - for (i, t) in enumerate(LLVM.elements(ty)) - if any_jltypes(t) - ev = extract_value!(B, cur, i - 1) - if isa(ev, LLVM.Instruction) - push!(added, ev.ref) - end - push!(todo, ev) - end - end - continue - end - if isa(ty, LLVM.IntegerType) - continue - end - if isa(ty, LLVM.FloatingPointType) - continue - end - msg = sprint() do io - println(io, "Enzyme illegal subtype") - println(io, "ty=", ty) - println(io, "cur=", cur) - println(io, "p=", p) - println(io, "startvals=", startvals) - end - throw(AssertionError(msg)) - end - return vals -end - function julia_post_cache_store( SI::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef, @@ -1124,7 +644,7 @@ function zero_allocation(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMType B = LLVM.IRBuilder(B) LLVMType = LLVM.LLVMType(LLVMType) obj = LLVM.Value(obj) - jlType = tape_type(LLVMType) + jlType = Compiler.tape_type(LLVMType) zeroAll = isTape == 0 func = LLVM.parent(position(B)) mod = LLVM.parent(func) @@ -1313,7 +833,7 @@ function julia_allocator(B::LLVM.IRBuilder, @nospecialize(LLVMType::LLVM.LLVMTyp esizeof(X) = X == Any ? 
sizeof(Int) : sizeof(X) - TT = tape_type(LLVMType) + TT = Compiler.tape_type(LLVMType) if esizeof(TT) != convert(Int, AlignedSize) GPUCompiler.@safe_error "Enzyme aligned size and Julia size disagree" AlignedSize = convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) @@ -1712,10 +1232,10 @@ else ) end -include("compiler/passes.jl") include("compiler/optimize.jl") include("compiler/interpreter.jl") include("compiler/validation.jl") +include("typeutils/inference.jl") import .Interpreter: isKWCallSignature @@ -1741,176 +1261,6 @@ Create the methodinstance pair, and lookup the primal return type. return primal end -function primal_interp_world( - @nospecialize(::ReverseMode), - world::UInt -) - mode = Enzyme.API.DEM_ReverseModeCombined - - CT = @static if VERSION >= v"1.11.0-DEV.1552" - EnzymeCacheToken( - typeof(DefaultCompilerTarget()), - false, - GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# - EnzymeCompilerParams, - false, - ) - else - Enzyme.Compiler.GLOBAL_REV_CACHE - end - - Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) -end - -function primal_interp_world( - @nospecialize(::ForwardMode), - world::UInt -) - mode = Enzyme.API.DEM_ForwardMode - - CT = @static if VERSION >= v"1.11.0-DEV.1552" - EnzymeCacheToken( - typeof(DefaultCompilerTarget()), - false, - GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# - EnzymeCompilerParams, - true, - ) - else - Enzyme.Compiler.GLOBAL_FWD_CACHE - end - - Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) -end - -@inline primal_interp_world( - @nospecialize(::ReverseModeSplit), - world::UInt) = primal_interp_world(Reverse, world) - -function primal_return_type_world( - @nospecialize(mode::Mode), - world::UInt, - @nospecialize(TT::Type), -) - Core.Compiler._return_type(primal_interp_world(mode, world), TT) -end - -function primal_return_type_world( - @nospecialize(mode::Mode), - world::UInt, - mi::Core.MethodInstance, -) - interp = 
primal_interp_world(mode, world) - something( - Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), - Any, - ) -end - -primal_return_type_world( - @nospecialize(mode::Mode), - world::UInt, - @nospecialize(FT::Type), - @nospecialize(TT::Type), - ) = primal_return_type_world(mode, world, Tuple{FT, TT.parameters...}) - -function primal_return_type_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type)) - @nospecialize - @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) - @assert mode <: Mode - mode = mode() - ft = ft.parameters[1] - tt = tt.parameters[1] - - # validation - ft <: Core.Builtin && - error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") - - # look up the method - method_error = :(throw(MethodError(ft, tt, $world))) - sig = Tuple{ft,tt.parameters...} - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - #interp = primal_interp_world(mode, world) - #method_table = Core.Compiler.method_table(interp) - method_table = nothing - mthds = Base._methods_by_ftype( - sig, - method_table, - -1, #=lim=# - world, - false, #=ambig=# - min_world, - max_world, - has_ambig, - ) - stub = Core.GeneratedFunctionStub( - identity, - Core.svec(:methodinstance, :mode, :ft, :tt), - Core.svec(), - ) - mthds === nothing && return stub(world, source, method_error) - length(mthds) == 1 || return stub(world, source, method_error) - - # look up the method and code instance - mtypes, msp, m = mthds[1] - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - m, - mtypes, - msp, - ) - ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo - - # prepare a new code info - new_ci = copy(ci) - empty!(new_ci.code) - @static if isdefined(Core, :DebugInfo) - new_ci.debuginfo = Core.DebugInfo(:none) - else - empty!(new_ci.codelocs) 
- resize!(new_ci.linetable, 1) # see note below - end - empty!(new_ci.ssaflags) - new_ci.ssavaluetypes = 0 - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] - new_ci.edges = Core.MethodInstance[mi] - # XXX: setting this edge does not give us proper method invalidation, see - # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. - # invoking `code_llvm` also does the necessary codegen, as does calling the - # underlying C methods -- which GPUCompiler does, so everything Just Works. - - # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :mode, :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:4] - - # return the codegen world age - res = primal_return_type_world(mode, world, mi) - push!(new_ci.code, Core.Compiler.ReturnNode(res)) - push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` - @static if isdefined(Core, :DebugInfo) - else - push!(new_ci.codelocs, 1) # see note below - end - new_ci.ssavaluetypes += 1 - - # NOTE: we keep the first entry of the original linetable, and use it for location info - # on the call to check_cache. we can't not have a codeloc (using 0 causes - # corruption of the back trace), and reusing the target function's info - # has as advantage that we see the name of the kernel in the backtraces. 
- - return new_ci -end - -@eval Base.@assume_effects :removable :foldable :nothrow @inline function primal_return_type(mode::Mode, ft::Type, tt::Type) - $(Expr(:meta, :generated_only)) - $(Expr(:meta, :generated, primal_return_type_generator)) -end - ## # Enzyme compiler step ## @@ -1981,20 +1331,6 @@ end include("rules/activityrules.jl") -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Const} = API.DFT_CONSTANT -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Active} = - API.DFT_OUT_DIFF -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Duplicated} = - API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicated} = - API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicatedFunc} = - API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:DuplicatedNoNeed} = - API.DFT_DUP_NONEED -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicatedNoNeed} = - API.DFT_DUP_NONEED - const DumpPreEnzyme = Ref(false) const DumpPostWrap = Ref(false) @@ -2037,14 +1373,7 @@ function enzyme!( @assert length(modifiedBetween) == length(TT.parameters) - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(primalf, i)), - ), - ) for i = 1:length(collect(parameters(primalf))) - ) + swiftself = has_swiftself(primalf) if swiftself push!(args_activity, API.DFT_CONSTANT) push!(args_typeInfo, TypeTree()) @@ -2164,10 +1493,10 @@ function enzyme!( tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) if utape != C_NULL - TapeType = EnzymeTapeToLoad{tape_type(LLVMType(utape))} + TapeType = EnzymeTapeToLoad{Compiler.tape_type(LLVMType(utape))} tape = utape elseif tape != C_NULL - TapeType = tape_type(LLVMType(tape)) + TapeType = Compiler.tape_type(LLVMType(tape)) else TapeType 
= Cvoid end @@ -2365,13 +1694,8 @@ function create_abi_wrapper( mod = LLVM.parent(enzymefn) ctx = LLVM.context(mod) - push!(function_attributes(enzymefn), EnumAttribute("alwaysinline", 0)) - hasNoInline = any( - map( - k -> kind(k) == kind(EnumAttribute("noinline")), - collect(function_attributes(enzymefn)), - ), - ) + push!(function_attributes(enzymefn), EnumAttribute("alwaysinline")) + hasNoInline = has_fn_attr(enzymefn, EnumAttribute("noinline")) if hasNoInline LLVM.API.LLVMRemoveEnumAttributeAtIndex( enzymefn, @@ -2509,9 +1833,9 @@ function create_abi_wrapper( tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) if utape != C_NULL - TapeType = EnzymeTapeToLoad{tape_type(LLVMType(utape))} + TapeType = EnzymeTapeToLoad{Compiler.tape_type(LLVMType(utape))} elseif tape != C_NULL - TapeType = tape_type(LLVMType(tape)) + TapeType = Compiler.tape_type(LLVMType(tape)) else TapeType = Cvoid end @@ -2623,7 +1947,7 @@ function create_abi_wrapper( end if tape != C_NULL tape = LLVM.LLVMType(tape) - jltape = convert(LLVM.LLVMType, tape_type(tape); allow_boxed = true) + jltape = convert(LLVM.LLVMType, Compiler.tape_type(tape); allow_boxed = true) push!(T_wrapperargs, jltape) else needs_tape = false @@ -2791,7 +2115,7 @@ function create_abi_wrapper( funcspec = my_methodinstance(Func, Tuple{}, world) llvmf = nested_codegen!(Mode, mod, funcspec, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) - Func_RT = Core.Compiler.typeinf_ext_toplevel(interp, funcspec).rettype + Func_RT = return_type(interp, funcspec) @assert Func_RT == NTuple{width,T′} _, psret, _ = get_return_info(Func_RT) args = LLVM.Value[] @@ -3169,297 +2493,6 @@ function fixup_metadata!(f::LLVM.Function) end end -struct RemovedParam end - -# Modified from GPUCompiler classify_arguments -function classify_arguments( - @nospecialize(source_sig::Type), - codegen_ft::LLVM.FunctionType, - has_sret::Bool, - 
has_returnroots::Bool, - has_swiftself::Bool, - parmsRemoved::Vector{UInt64}, -) - codegen_types = parameters(codegen_ft) - - args = [] - codegen_i = 1 - orig_i = 1 - if has_sret - if !in(orig_i - 1, parmsRemoved) - codegen_i += 1 - end - orig_i += 1 - end - if has_returnroots - if !in(orig_i - 1, parmsRemoved) - codegen_i += 1 - end - orig_i += 1 - end - if has_swiftself - if !in(orig_i - 1, parmsRemoved) - codegen_i += 1 - end - orig_i += 1 - end - for (source_i, source_typ) in enumerate(source_sig.parameters) - if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - push!(args, (cc = GPUCompiler.GHOST, typ = source_typ, arg_i = source_i)) - continue - end - if in(orig_i - 1, parmsRemoved) - push!(args, (cc = RemovedParam, typ = source_typ)) - orig_i += 1 - continue - end - codegen_typ = codegen_types[codegen_i] - - if codegen_typ isa LLVM.PointerType - llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) - # pointers are used for multiple kinds of arguments - # - literal pointer values - if source_typ <: Ptr || source_typ <: Core.LLVMPtr - @assert llvm_source_typ == codegen_typ - push!( - args, - ( - cc = GPUCompiler.BITS_VALUE, - typ = source_typ, - arg_i = source_i, - codegen = (typ = codegen_typ, i = codegen_i), - ), - ) - # - boxed values - # XXX: use `deserves_retbox` instead? 
- elseif llvm_source_typ isa LLVM.PointerType - @assert llvm_source_typ == codegen_typ - push!( - args, - ( - cc = GPUCompiler.MUT_REF, - typ = source_typ, - arg_i = source_i, - codegen = (typ = codegen_typ, i = codegen_i), - ), - ) - # - references to aggregates - else - @assert llvm_source_typ != codegen_typ - push!( - args, - ( - cc = GPUCompiler.BITS_REF, - typ = source_typ, - arg_i = source_i, - codegen = (typ = codegen_typ, i = codegen_i), - ), - ) - end - else - push!( - args, - ( - cc = GPUCompiler.BITS_VALUE, - typ = source_typ, - arg_i = source_i, - codegen = (typ = codegen_typ, i = codegen_i), - ), - ) - end - - codegen_i += 1 - orig_i += 1 - end - - return args -end - -function isSpecialPtr(@nospecialize(Ty::LLVM.LLVMType)) - if !isa(Ty, LLVM.PointerType) - return false - end - AS = LLVM.addrspace(Ty) - return 10 <= AS && AS <= 13 -end - -mutable struct CountTrackedPointers - count::UInt - all::Bool - derived::Bool -end - -function CountTrackedPointers(@nospecialize(T::LLVM.LLVMType)) - res = CountTrackedPointers(0, true, false) - - if isa(T, LLVM.PointerType) - if isSpecialPtr(T) - res.count += 1 - if LLVM.addrspace(T) != Tracked - res.derived = true - end - end - elseif isa(T, LLVM.StructType) - for ElT in elements(T) - sub = CountTrackedPointers(ElT) - res.count += sub.count - res.all &= sub.all - res.derived |= sub.derived - end - elseif isa(T, LLVM.ArrayType) - sub = CountTrackedPointers(eltype(T)) - res.count += sub.count - res.all &= sub.all - res.derived |= sub.derived - res.count *= length(T) - elseif isa(T, LLVM.VectorType) - sub = CountTrackedPointers(eltype(T)) - res.count += sub.count - res.all &= sub.all - res.derived |= sub.derived - res.count *= size(T) - end - if res.count == 0 - res.all = false - end - return res -end - -# must deserve sret -function deserves_rooting(@nospecialize(T::LLVM.LLVMType)) - tracked = CountTrackedPointers(T) - @assert !tracked.derived - if tracked.count != 0 && !tracked.all - return true # tracked.count; - 
end - return false -end - -# https://github.com/JuliaLang/julia/blob/64378db18b512677fc6d3b012e6d1f02077af191/src/cgutils.cpp#L823 -# returns if all unboxed -function for_each_uniontype_small(@nospecialize(f), @nospecialize(ty::Type), counter::Base.RefValue{Int} = Ref(0)) - if counter[] > 127 - return false - end - if ty isa Union - allunbox = for_each_uniontype_small(f, ty.a, counter) - allunbox &= for_each_uniontype_small(f, ty.b, counter) - return allunbox - end - # https://github.com/JuliaLang/julia/blob/170d6439445c86e640214620dad3423d2bb42337/src/codegen.cpp#L1233 - if Base.isconcretetype(ty) && !ismutabletype(ty) && Base.datatype_pointerfree(ty) - counter[] += 1 - f(ty) - return true - end - return false -end - -# From https://github.com/JuliaLang/julia/blob/038d31463f0ef744c8308bdbe87339b9c3f0b890/src/cgutils.cpp#L3108 -function union_alloca_type(@nospecialize(UT::Type)) - nbytes = 0 - function inner(@nospecialize(jlrettype::Type)) - if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) - nbytes = max(nbytes, sizeof(jlrettype)) - end - end - for_each_uniontype_small(inner, UT) - return nbytes -end - -# From https://github.com/JuliaLang/julia/blob/e6bf81f39a202eedc7bd4f310c1ab60b5b86c251/src/codegen.cpp#L6447 -function is_sret(@nospecialize(jlrettype::Type)) - if jlrettype === Union{} - # jlrettype == (jl_value_t*)jl_bottom_type - return false - elseif Base.isstructtype(jlrettype) && - Base.issingletontype(jlrettype) && - isa(jlrettype, DataType) - # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) - return false - elseif jlrettype isa Union # jl_is_uniontype(jlrettype) - if union_alloca_type(jlrettype) > 0 - # sret, also a regular return here - return true - end - return false - elseif !GPUCompiler.deserves_retbox(jlrettype) - rt = convert(LLVMType, jlrettype) - if !isa(rt, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, rt) - return true - end - end - return false -end -function 
is_sret_union(@nospecialize(jlrettype::Type)) - if jlrettype === Union{} - # jlrettype == (jl_value_t*)jl_bottom_type - return false - elseif Base.isstructtype(jlrettype) && - Base.issingletontype(jlrettype) && - isa(jlrettype, DataType) - # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) - return false - elseif jlrettype isa Union # jl_is_uniontype(jlrettype) - if union_alloca_type(jlrettype) > 0 - # sret, also a regular return here - return true - end - end - return false -end - -# https://github.com/JuliaLang/julia/blob/0a696a3842750fcedca8832bc0aabe9096c7658f/src/codegen.cpp#L6812 -function get_return_info( - @nospecialize(jlrettype::Type), -)::Tuple{Union{Nothing,Type},Union{Nothing,Type},Union{Nothing,Type}} - sret = nothing - returnRoots = nothing - rt = nothing - if jlrettype === Union{} - rt = Nothing - elseif Base.isstructtype(jlrettype) && - Base.issingletontype(jlrettype) && - isa(jlrettype, DataType) - rt = Nothing - elseif jlrettype isa Union - nbytes = 0 - allunbox = for_each_uniontype_small(jlrettype) do jlrettype - if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) - nbytes = max(nbytes, sizeof(jlrettype)) - end - end - if nbytes != 0 - rt = NamedTuple{(Symbol("1"), Symbol("2")),Tuple{Any,UInt8}} - # Pointer to?, Ptr{NTuple{UInt8, allunbox} - sret = Ptr{jlrettype} - elseif allunbox - rt = UInt8 - else - rt = Any - end - elseif jlrettype <: Tuple && in(Any, jlrettype.parameters) - rt = Any - elseif !GPUCompiler.deserves_retbox(jlrettype) - lRT = convert(LLVMType, jlrettype) - if !isa(lRT, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, lRT) - sret = Ptr{jlrettype} - tracked = CountTrackedPointers(lRT) - @assert !tracked.derived - if tracked.count != 0 && !tracked.all - returnRoots = Ptr{AnyArray(Int(tracked.count))} - end - else - rt = jlrettype - end - else - # retbox - rt = Ptr{jlrettype} - end - - return (rt, sret, returnRoots) -end - # Modified from GPUCompiler/src/irgen.jl:365 lower_byval 
function lower_convention( @nospecialize(functy::Type), @@ -3492,14 +2525,7 @@ function lower_convention( # TODO removed implications retRemoved, parmsRemoved = removed_ret_parms(entry_f) - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(entry_f, i)), - ), - ) for i = 1:length(collect(parameters(entry_f))) - ) + swiftself = has_swiftself(entry_f) @assert !swiftself "Swiftself attribute coming from differentiable context is not supported" prargs = classify_arguments(functy, entry_ft, sret, returnRoots, swiftself, parmsRemoved) @@ -3581,18 +2607,8 @@ function lower_convention( set_subprogram!(wrapper_f, sfn) end - hasReturnsTwice = any( - map( - k -> kind(k) == kind(EnumAttribute("returns_twice")), - collect(function_attributes(entry_f)), - ), - ) - hasNoInline = any( - map( - k -> kind(k) == kind(EnumAttribute("noinline")), - collect(function_attributes(entry_f)), - ), - ) + hasReturnsTwice = has_fn_attr(entry_f, EnumAttribute("returns_twice")) + hasNoInline = has_fn_attr(entry_f, EnumAttribute("noinline")) if hasNoInline LLVM.API.LLVMRemoveEnumAttributeAtIndex( entry_f, @@ -4248,16 +3264,9 @@ function GPUCompiler.codegen( primal_job = CompilerJob(primal, config2, job.world) # TODO EnzymeInterp params, etc end + GPUCompiler.prepare_job!(primal_job) + mod, meta = GPUCompiler.emit_llvm(primal_job; libraries=true, toplevel=toplevel, optimize=false, cleanup=false, only_entry=false, validate=false) - mod, meta = GPUCompiler.codegen( - :llvm, - primal_job; - optimize = false, - toplevel = toplevel, - cleanup = false, - validate = false, - parent_job = parent_job, - ) prepare_llvm(mod, primal_job, meta) for f in functions(mod) permit_inlining!(f) @@ -4320,20 +3329,10 @@ function GPUCompiler.codegen( end toremove = String[] for f in functions(mod) - if !any( - map( - k -> kind(k) == kind(EnumAttribute("alwaysinline")), - collect(function_attributes(f)), - ), - ) + if !has_fn_attr(f, 
EnumAttribute("alwaysinline")) continue end - if !any( - map( - k -> kind(k) == kind(EnumAttribute("returns_twice")), - collect(function_attributes(f)), - ), - ) + if !has_fn_attr(f, EnumAttribute("returns_twice")) push!(function_attributes(f), EnumAttribute("returns_twice")) push!(toremove, name(f)) end @@ -4411,14 +3410,7 @@ function GPUCompiler.codegen( end expectLen -= length(parmsRemoved) - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(f, i)), - ), - ) for i = 1:length(collect(parameters(f))) - ) + swiftself = has_swiftself(f) if swiftself expectLen += 1 @@ -5281,14 +4273,7 @@ end Ty = eltype(FT) reg = active_reg_inner(Ty, (), world) if reg == DupState || reg == MixedState - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(primalf, i)), - ), - ) for i = 1:length(collect(parameters(primalf))) - ) + swiftself = has_swiftself(primalf) todo = LLVM.Value[parameters(primalf)[1+swiftself]] done = Set{LLVM.Value}() doneInst = Set{LLVM.Instruction}() @@ -5479,20 +4464,10 @@ end end end end - if !any( - map( - k -> kind(k) == kind(EnumAttribute("alwaysinline")), - collect(function_attributes(f)), - ), - ) + if !has_fn_attr(f, EnumAttribute("alwaysinline")) continue end - if !any( - map( - k -> kind(k) == kind(EnumAttribute("returns_twice")), - collect(function_attributes(f)), - ), - ) + if !has_fn_attr(f, EnumAttribute("returns_twice")) push!(function_attributes(f), EnumAttribute("returns_twice")) push!(toremove, name(f)) end @@ -5726,20 +4701,6 @@ end args..., ) - -function jl_set_typeof(v::Ptr{Cvoid}, @nospecialize(T::Type)) - tag = reinterpret(Ptr{Any}, reinterpret(UInt, v) - 8) - Base.unsafe_store!(tag, T) # set tag - return nothing -end - -@generated function splatnew(::Type{T}, args::TT) where {T,TT<:Tuple} - return quote - Base.@_inline_meta - $(Expr(:splatnew, :T, :args)) - end -end - include("typeutils/recursive_add.jl") @inline
function default_adjoint(T) @@ -6161,7 +5122,7 @@ end if needs_tape && !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType)) tape = callparams[end] if TapeType <: EnzymeTapeToLoad - llty = from_tape_type(eltype(TapeType)) + llty = Compiler.from_tape_type(eltype(TapeType)) tape = bitcast!( builder, tape, @@ -6171,7 +5132,7 @@ end API.SetMustCache!(tape) callparams[end] = tape else - llty = from_tape_type(TapeType) + llty = Compiler.from_tape_type(TapeType) @assert value_type(tape) == llty end end @@ -6346,15 +5307,6 @@ const cache_lock = ReentrantLock() end end -@inline remove_innerty(::Type{<:Const}) = Const -@inline remove_innerty(::Type{<:Active}) = Active -@inline remove_innerty(::Type{<:Duplicated}) = Duplicated -@inline remove_innerty(::Type{<:DuplicatedNoNeed}) = DuplicatedNoNeed -@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated -@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed -@inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated -@inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated - @inline function thunkbase( mi::Core.MethodInstance, World::Union{UInt, Nothing}, @@ -6395,7 +5347,7 @@ end interp = GPUCompiler.get_interpreter(tmp_job) # TODO check compile return here, early - rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + rrt = return_type(interp, mi) run_enzyme = true diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 66313d3853..f9769881a4 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -437,961 +437,6 @@ else end end -function addNA(@nospecialize(inst::LLVM.Instruction), @nospecialize(node::LLVM.Metadata), MD::LLVM.MDKind) - md = metadata(inst) - next = nothing - if haskey(md, MD) - next = LLVM.MDNode(Metadata[node, operands(md[MD])...]) - else - next = LLVM.MDNode(Metadata[node]) - end - setindex!(md, next, MD) -end - -function addr13NoAlias(mod::LLVM.Module) - ctx = LLVM.context(mod) - dom = 
API.EnzymeAnonymousAliasScopeDomain("addr13", ctx) - scope = API.EnzymeAnonymousAliasScope(dom, "na_addr13") - aliasscope = noalias = scope - for f in functions(mod), bb in blocks(f), inst in instructions(bb) - if isa(inst, LLVM.StoreInst) - addNA(inst, noalias, LLVM.MD_noalias) - elseif isa(inst, LLVM.CallInst) - fn = LLVM.called_operand(inst) - if isa(fn, LLVM.Function) - name = LLVM.name(fn) - if startswith(name, "llvm.memcpy") || startswith(name, "llvm.memmove") - addNA(inst, noalias, LLVM.MD_noalias) - end - end - elseif isa(inst, LLVM.LoadInst) - ty = value_type(inst) - if isa(ty, LLVM.PointerType) - if addrspace(ty) == 13 - addNA(inst, aliasscope, LLVM.MD_alias_scope) - end - end - end - end -end - -## given code like -# % a = alloca -# ... -# memref(cast(%a), %b, constant size == sizeof(a)) -# -# turn this into load/store, as this is more -# amenable to caching analysis infrastructure -function memcpy_alloca_to_loadstore(mod::LLVM.Module) - dl = datalayout(mod) - for f in functions(mod) - if length(blocks(f)) != 0 - bb = first(blocks(f)) - todel = Set{LLVM.Instruction}() - for alloca in instructions(bb) - if !isa(alloca, LLVM.AllocaInst) - continue - end - todo = Tuple{LLVM.Instruction,LLVM.Value}[(alloca, alloca)] - copy = nothing - legal = true - elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca)) - lifetimestarts = LLVM.Instruction[] - while length(todo) > 0 - cur, prev = pop!(todo) - if isa(cur, LLVM.AllocaInst) || - isa(cur, LLVM.AddrSpaceCastInst) || - isa(cur, LLVM.BitCastInst) - for u in LLVM.uses(cur) - u = LLVM.user(u) - push!(todo, (u, cur)) - end - continue - end - if isa(cur, LLVM.CallInst) && - isa(LLVM.called_operand(cur), LLVM.Function) - intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur)) - if intr == LLVM.Intrinsic("llvm.lifetime.start").id - push!(lifetimestarts, cur) - continue - end - if intr == LLVM.Intrinsic("llvm.lifetime.end").id - continue - end - if intr == LLVM.Intrinsic("llvm.memcpy").id - sz = operands(cur)[3] 
- if operands(cur)[1] == prev && - isa(sz, LLVM.ConstantInt) && - convert(Int, sz) == sizeof(dl, elty) - if copy === nothing || copy == cur - copy = cur - continue - end - end - end - end - - # read only insts of arg, don't matter - if isa(cur, LLVM.LoadInst) - continue - end - if isa(cur, LLVM.CallInst) && - isa(LLVM.called_operand(cur), LLVM.Function) - legalc = true - for (i, ci) in enumerate(operands(cur)[1:end-1]) - if ci == prev - nocapture = false - readonly = false - for a in collect( - parameter_attributes(LLVM.called_operand(cur), i), - ) - if kind(a) == kind(EnumAttribute("readonly")) - readonly = true - end - if kind(a) == kind(EnumAttribute("readnone")) - readonly = true - end - if kind(a) == kind(EnumAttribute("nocapture")) - nocapture = true - end - end - if !nocapture || !readonly - legalc = false - break - end - end - end - if legalc - continue - end - end - - legal = false - break - end - - if legal && copy !== nothing - B = LLVM.IRBuilder() - position!(B, copy) - dst = operands(copy)[1] - src = operands(copy)[2] - dst0 = bitcast!( - B, - dst, - LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst))), - ) - - dst = - bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) - src = - bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) - - src = load!(B, elty, src) - FT = LLVM.FunctionType( - LLVM.VoidType(), - [LLVM.IntType(64), value_type(dst0)], - ) - lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT) - call!( - B, - FT, - lifetimestart, - LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0], - ) - store!(B, src, dst) - push!(todel, copy) - end - for lt in lifetimestarts - push!(todel, lt) - end - end - for inst in todel - eraseInst(LLVM.parent(inst), inst) - end - end - end -end - -# If there is a phi node of a decayed value, Enzyme may need to cache it -# Here we force all decayed pointer phis to first addrspace from 10 -function nodecayed_phis!(mod::LLVM.Module) - # Simple handler to 
fix addrspace 11 - #complex handler for addrspace 13, which itself comes from a load of an - # addrspace 10 - ctx = LLVM.context(mod) - for f in functions(mod) - - guaranteedInactive = false - - for attr in collect(function_attributes(f)) - if !isa(attr, LLVM.StringAttribute) - continue - end - if kind(attr) == "enzyme_inactive" - guaranteedInactive = true - break - end - end - - if guaranteedInactive - continue - end - - - entry_ft = LLVM.function_type(f) - - RT = LLVM.return_type(entry_ft) - inactiveRet = RT == LLVM.VoidType() - - for attr in collect(return_attributes(f)) - if !isa(attr, LLVM.StringAttribute) - continue - end - if kind(attr) == "enzyme_inactive" - inactiveRet = true - break - end - end - - if inactiveRet - for idx in length(collect(parameters(f))) - inactiveParm = false - for attr in collect(parameter_attributes(f, idx)) - if !isa(attr, LLVM.StringAttribute) - continue - end - if kind(attr) == "enzyme_inactive" - inactiveParm = true - break - end - end - if !inactiveParm - inactiveRet = false - break - end - end - if inactiveRet - continue - end - end - - offty = LLVM.IntType(8 * sizeof(Int)) - i8 = LLVM.IntType(8) - - for addr in (11, 13) - - nextvs = Dict{LLVM.PHIInst,LLVM.PHIInst}() - mtodo = Vector{LLVM.PHIInst}[] - goffsets = Dict{LLVM.PHIInst,LLVM.PHIInst}() - nonphis = LLVM.Instruction[] - anyV = false - for bb in blocks(f) - todo = LLVM.PHIInst[] - nonphi = nothing - for inst in instructions(bb) - if !isa(inst, LLVM.PHIInst) - nonphi = inst - break - end - ty = value_type(inst) - if !isa(ty, LLVM.PointerType) - continue - end - if addrspace(ty) != addr - continue - end - if addr == 11 - all_args = true - addrtodo = Value[inst] - seen = Set{LLVM.Value}() - - while length(addrtodo) != 0 - v = pop!(addrtodo) - base, _ = get_base_and_offset(v; offsetAllowed=false) - if in(base, seen) - continue - end - push!(seen, base) - if isa(base, LLVM.Argument) && addrspace(value_type(base)) == 11 - continue - end - if isa(base, LLVM.PHIInst) - for (v, 
_) in LLVM.incoming(base) - push!(addrtodo, v) - end - continue - end - all_args = false - break - end - if all_args - continue - end - end - - push!(todo, inst) - nb = IRBuilder() - position!(nb, inst) - el_ty = if addr == 11 - eltype(ty) - else - LLVM.StructType(LLVM.LLVMType[]) - end - nphi = phi!( - nb, - LLVM.PointerType(el_ty, 10), - "nodecayed." * LLVM.name(inst), - ) - nextvs[inst] = nphi - anyV = true - - goffsets[inst] = phi!(nb, offty, "nodecayedoff." * LLVM.name(inst)) - end - push!(mtodo, todo) - push!(nonphis, nonphi) - end - for (bb, todo, nonphi) in zip(blocks(f), mtodo, nonphis) - - for inst in todo - ty = value_type(inst) - el_ty = if addr == 11 - eltype(ty) - else - LLVM.StructType(LLVM.LLVMType[]) - end - nvs = Tuple{LLVM.Value,LLVM.BasicBlock}[] - offsets = Tuple{LLVM.Value,LLVM.BasicBlock}[] - for (v, pb) in LLVM.incoming(inst) - done = false - for ((nv, pb0), (offset, pb1)) in zip(nvs, offsets) - if pb0 == pb - push!(nvs, (nv, pb)) - push!(offsets, (offset, pb)) - done = true - break - end - end - if done - continue - end - b = IRBuilder() - position!(b, terminator(pb)) - - - v0 = v - @inline function getparent(@nospecialize(v::LLVM.Value), @nospecialize(offset::LLVM.Value), hasload::Bool) - if addr == 11 && addrspace(value_type(v)) == 10 - return v, offset, hasload - end - if addr == 13 && hasload && addrspace(value_type(v)) == 10 - return v, offset, hasload - end - if addr == 13 && !hasload - if isa(v, LLVM.LoadInst) - v2, o2, hl2 = getparent(operands(v)[1], LLVM.ConstantInt(offty, 0), true) - rhs = LLVM.ConstantInt(offty, 0) - if o2 != rhs - msg = sprint() do io::IO - println( - io, - "Enzyme internal error addr13 load doesn't keep offset 0", - ) - println(io, "v=", string(v)) - println(io, "v2=", string(v2)) - println(io, "o2=", string(o2)) - println(io, "hl2=", string(hl2)) - println(io, "offty=", string(offty)) - println(io, "rhs=", string(rhs)) - end - throw(AssertionError(msg)) - end - return v2, offset, true - end - if isa(v, 
LLVM.CallInst) - cf = LLVM.called_operand(v) - if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" - ld = operands(v)[2] - while isa(ld, LLVM.BitCastInst) || isa(ld, LLVM.AddrSpaceCastInst) - ld = operands(ld)[1] - end - if isa(ld, LLVM.LoadInst) - v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true) - rhs = LLVM.ConstantInt(offty, sizeof(Int)) - - base_2, off_2 = get_base_and_offset(v2) - base_1, off_1 = get_base_and_offset(operands(v)[1]) - - if o2 == rhs && base_1 == base_2 && off_1 == off_2 - return operands(v)[1], offset, true - end - - pty = TypeTree(API.DT_Pointer, LLVM.context(ld)) - only!(pty, -1) - rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty) - metadata(rhs)["enzyme_type"] = to_md(pty, ctx) - lhs = ptrtoint!(b, operands(v)[2], offty) - metadata(rhs)["enzyme_type"] = to_md(pty, ctx) - off2 = nuwsub!(b, lhs, rhs) - ity = TypeTree(API.DT_Integer, LLVM.context(ld)) - only!(ity, -1) - metadata(off2)["enzyme_type"] = to_md(ity, ctx) - add = nuwadd!(b, offset, off2) - metadata(add)["enzyme_type"] = to_md(ity, ctx) - return operands(v)[1], add, true - end - end - end - end - - if addr == 13 && isa(v, LLVM.ConstantExpr) - if opcode(v) == LLVM.API.LLVMAddrSpaceCast - v2 = operands(v)[1] - if addrspace(value_type(v2)) == 0 - if addr == 13 && isa(v, LLVM.ConstantExpr) - v2 = const_addrspacecast( - operands(v)[1], - LLVM.PointerType(eltype(value_type(v)), 10), - ) - return v2, offset, hasload - end - end - end - end - - if isa(v, LLVM.ConstantExpr) - if opcode(v) == LLVM.API.LLVMAddrSpaceCast - v2 = operands(v)[1] - if addrspace(value_type(v2)) == 10 - return v2, offset, hasload - end - if addrspace(value_type(v2)) == 0 - if addr == 11 - v2 = const_addrspacecast( - v2, - LLVM.PointerType(eltype(value_type(v)), 10), - ) - return v2, offset, hasload - end - end - if LLVM.isnull(v2) - v2 = const_addrspacecast( - v2, - LLVM.PointerType(eltype(value_type(v)), 10), - ) - return v2, offset, hasload - end - end - if opcode(v) 
== LLVM.API.LLVMBitCast - preop = operands(v)[1] - while isa(preop, LLVM.ConstantExpr) && opcode(preop) == LLVM.API.LLVMBitCast - preop = operands(preop)[1] - end - v2, offset, skipload = - getparent(preop, offset, hasload) - v2 = const_bitcast( - v2, - LLVM.PointerType( - eltype(value_type(v)), - addrspace(value_type(v2)), - ), - ) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end - - if opcode(v) == LLVM.API.LLVMGetElementPtr - v2, offset, skipload = - getparent(operands(v)[1], offset, hasload) - offset = const_add( - offset, - API.EnzymeComputeByteOffsetOfGEP(b, v, offty), - ) - v2 = const_bitcast( - v2, - LLVM.PointerType( - eltype(value_type(v)), - addrspace(value_type(v2)), - ), - ) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end - - end - - if isa(v, LLVM.AddrSpaceCastInst) - if addrspace(value_type(operands(v)[1])) == 0 - v2 = addrspacecast!( - b, - operands(v)[1], - LLVM.PointerType(eltype(value_type(v)), 10), - ) - return v2, offset, hasload - end - nv, noffset, nhasload = - getparent(operands(v)[1], offset, hasload) - if eltype(value_type(nv)) != eltype(value_type(v)) - nv = bitcast!( - b, - nv, - LLVM.PointerType( - eltype(value_type(v)), - addrspace(value_type(nv)), - ), - ) - end - return nv, noffset, nhasload - end - - if isa(v, LLVM.BitCastInst) - preop = operands(v)[1] - while isa(preop, LLVM.BitCastInst) - preop = operands(preop)[1] - end - v2, offset, skipload = - getparent(preop, offset, hasload) - v2 = bitcast!( - b, - v2, - LLVM.PointerType( - eltype(value_type(v)), - addrspace(value_type(v2)), - ), - ) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end - - if isa(v, LLVM.GetElementPtrInst) && all( - x -> (isa(x, LLVM.ConstantInt) && convert(Int, x) == 0), - operands(v)[2:end], - ) - v2, offset, skipload = - getparent(operands(v)[1], offset, hasload) - v2 = bitcast!( - b, - v2, - LLVM.PointerType( - 
eltype(value_type(v)), - addrspace(value_type(v2)), - ), - ) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end - - if isa(v, LLVM.GetElementPtrInst) - v2, offset, skipload = - getparent(operands(v)[1], offset, hasload) - offset = nuwadd!( - b, - offset, - API.EnzymeComputeByteOffsetOfGEP(b, v, offty), - ) - v2 = bitcast!( - b, - v2, - LLVM.PointerType( - eltype(value_type(v)), - addrspace(value_type(v2)), - ), - ) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end - - undeforpoison = isa(v, LLVM.UndefValue) - @static if LLVM.version() >= v"12" - undeforpoison |= isa(v, LLVM.PoisonValue) - end - if undeforpoison - return LLVM.UndefValue( - LLVM.PointerType(eltype(value_type(v)), 10), - ), - offset, - addr == 13 - end - - if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v) - offset = nuwadd!(b, offset, goffsets[v]) - nv = nextvs[v] - return nv, offset, addr == 13 - end - - if isa(v, LLVM.SelectInst) - lhs_v, lhs_offset, lhs_skipload = - getparent(operands(v)[2], offset, hasload) - rhs_v, rhs_offset, rhs_skipload = - getparent(operands(v)[3], offset, hasload) - if value_type(lhs_v) != value_type(rhs_v) || - value_type(lhs_offset) != value_type(rhs_offset) || - lhs_skipload != rhs_skipload - msg = sprint() do io - println( - io, - "Could not analyze [select] garbage collection behavior of", - ) - println(io, " v0: ", string(v0)) - println(io, " v: ", string(v)) - println(io, " offset: ", string(offset)) - println(io, " hasload: ", string(hasload)) - println(io, " lhs_v", lhs_v) - println(io, " rhs_v", rhs_v) - println(io, " lhs_offset", lhs_offset) - println(io, " rhs_offset", rhs_offset) - println(io, " lhs_skipload", lhs_skipload) - println(io, " rhs_skipload", rhs_skipload) - end - bt = GPUCompiler.backtrace(inst) - throw(EnzymeInternalError(msg, string(f), bt)) - end - return select!(b, operands(v)[1], lhs_v, rhs_v), - select!(b, operands(v)[1], lhs_offset, rhs_offset), - 
lhs_skipload - end - - msg = sprint() do io - println(io, "Could not analyze garbage collection behavior of") - println(io, " inst: ", string(inst)) - println(io, " v0: ", string(v0)) - println(io, " v: ", string(v)) - println(io, " offset: ", string(offset)) - println(io, " hasload: ", string(hasload)) - end - bt = GPUCompiler.backtrace(inst) - throw(EnzymeInternalError(msg, string(f), bt)) - end - - v, offset, hadload = getparent(v, LLVM.ConstantInt(offty, 0), false) - - if addr == 13 - @assert hadload - end - - if eltype(value_type(v)) != el_ty - v = bitcast!( - b, - v, - LLVM.PointerType(el_ty, addrspace(value_type(v))), - ) - end - push!(nvs, (v, pb)) - push!(offsets, (offset, pb)) - end - - nb = IRBuilder() - position!(nb, nonphi) - - offset = goffsets[inst] - append!(LLVM.incoming(offset), offsets) - if all(x -> x[1] == offsets[1][1], offsets) - offset = offsets[1][1] - end - - nphi = nextvs[inst] - - function ogbc(@nospecialize(x::LLVM.Value)) - while isa(x, LLVM.BitCastInst) - x = operands(x)[1] - end - return x - end - - if all(x -> ogbc(x[1]) == ogbc(nvs[1][1]), nvs) - bc = ogbc(nvs[1][1]) - if value_type(bc) != value_type(nphi) - bc = bitcast!(nb, bc, value_type(nphi)) - end - replace_uses!(nphi, bc) - LLVM.API.LLVMInstructionEraseFromParent(nphi) - nphi = bc - else - append!(LLVM.incoming(nphi), nvs) - end - - if addr == 13 - @static if VERSION < v"1.11-" - nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) - nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) - nphi = load!(nb, ty, nphi) - else - base_obj = nphi - - jlt = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10) - pjlt = LLVM.PointerType(jlt) - - nphi = get_memory_data(nb, nphi) - nphi = bitcast!(nb, nphi, pjlt) - - GTy = LLVM.FunctionType(LLVM.PointerType(jlt, 13), LLVM.LLVMType[jlt, pjlt]) - gcloaded, _ = get_function!( - mod, - "julia.gc_loaded", - GTy - ) - nphi = call!(nb, GTy, gcloaded, LLVM.Value[base_obj, nphi]) - if value_type(nphi) != ty - nphi = bitcast!(nb, nphi, ty) - 
end - end - else - nphi = addrspacecast!(nb, nphi, ty) - end - if !isa(offset, LLVM.ConstantInt) || convert(Int64, offset) != 0 - nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, addrspace(ty))) - nphi = gep!(nb, i8, nphi, [offset]) - nphi = bitcast!(nb, nphi, ty) - end - replace_uses!(inst, nphi) - end - for inst in todo - LLVM.API.LLVMInstructionEraseFromParent(inst) - end - end - end - end - return nothing -end - -function fix_decayaddr!(mod::LLVM.Module) - for f in functions(mod) - invalid = LLVM.Instruction[] - for bb in blocks(f), inst in instructions(bb) - if !isa(inst, LLVM.AddrSpaceCastInst) - continue - end - prety = value_type(operands(inst)[1]) - postty = value_type(inst) - if addrspace(prety) != 10 - continue - end - if addrspace(postty) != 0 - continue - end - push!(invalid, inst) - end - - for inst in invalid - temp = nothing - for u in LLVM.uses(inst) - st = LLVM.user(u) - # Storing _into_ the decay addr is okay - # we just cannot store the decayed addr into - # somewhere - if isa(st, LLVM.StoreInst) - if operands(st)[2] == inst - LLVM.API.LLVMSetOperand(st, 2 - 1, operands(inst)[1]) - continue - end - end - if isa(st, LLVM.LoadInst) - LLVM.API.LLVMSetOperand(st, 1 - 1, operands(inst)[1]) - continue - end - # if isa(st, LLVM.InsertValueInst) - # if operands(st)[1] == inst - # push!(invalid, st) - # LLVM.API.LLVMSetOperand(st, 1-1, LLVM.UndefValue(value_type(inst))) - # continue - # end - # if operands(st)[2] == inst - # push!(invalid, st) - # LLVM.API.LLVMSetOperand(st, 2-1, LLVM.UndefValue(value_type(inst))) - # continue - # end - # end - if !isa(st, LLVM.CallInst) - bt = GPUCompiler.backtrace(st) - msg = sprint() do io::IO - println(io, string(f)) - println(io, inst) - println(io, st) - print(io, "Illegal decay of nonnull\n") - if bt !== nothing - print(io, "\nCaused by:") - Base.show_backtrace(io, bt) - println(io) - end - end - throw(AssertionError(msg)) - end - - fop = operands(st)[end] - - intr = LLVM.API.LLVMGetIntrinsicID(fop) - - if intr == 
LLVM.Intrinsic("llvm.memcpy").id || - intr == LLVM.Intrinsic("llvm.memmove").id || - intr == LLVM.Intrinsic("llvm.memset").id - newvs = LLVM.Value[] - for (i, v) in enumerate(operands(st)[1:end-1]) - if v == inst - LLVM.API.LLVMSetOperand(st, i - 1, operands(inst)[1]) - push!(newvs, operands(inst)[1]) - continue - end - push!(newvs, v) - end - - nb = IRBuilder() - position!(nb, st) - if intr == LLVM.Intrinsic("llvm.memcpy").id - newi = memcpy!(nb, newvs[1], 0, newvs[2], 0, newvs[3]) - elseif intr == LLVM.Intrinsic("llvm.memmove").id - newi = memmove!(nb, newvs[1], 0, newvs[2], 0, newvs[3]) - else - newi = memset!(nb, newvs[1], newvs[2], newvs[3], 0) - end - - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(operands(st))-1) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(st, idx) - - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(st, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newi, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - - API.EnzymeCopyMetadata(newi, st) - - LLVM.API.LLVMInstructionEraseFromParent(st) - continue - end - mayread = false - maywrite = false - sret = true - sretkind = kind(if LLVM.version().major >= 12 - TypeAttribute("sret", LLVM.Int32Type()) - else - EnumAttribute("sret") - end) - for (i, v) in enumerate(operands(st)[1:end-1]) - if v == inst - readnone = false - readonly = false - writeonly = false - t_sret = false - for a in collect(parameter_attributes(fop, i)) - if kind(a) == sretkind - t_sret = true - end - if kind(a) == kind(StringAttribute("enzyme_sret")) - t_sret = true - end - # if kind(a) == kind(StringAttribute("enzyme_sret_v")) - # t_sret = true - # end - if kind(a) == kind(EnumAttribute("readonly")) - readonly 
= true - end - if kind(a) == kind(EnumAttribute("readnone")) - readnone = true - end - if kind(a) == kind(EnumAttribute("writeonly")) - writeonly = true - end - end - if !t_sret - sret = false - end - if readnone - continue - end - if !readonly - maywrite = true - end - if !writeonly - mayread = true - end - end - end - if !sret - msg = sprint() do io - println(io, "Enzyme Internal Error: did not have sret when expected") - println(io, "f=", string(f)) - println(io, "inst=", string(inst)) - println(io, "st=", string(st)) - println(io, "fop=", string(fop)) - end - throw(AssertionError(msg)) - end - - elt = eltype(value_type(inst)) - if temp === nothing - nb = IRBuilder() - position!(nb, first(instructions(first(blocks(f))))) - temp = alloca!(nb, elt) - end - if mayread - nb = IRBuilder() - position!(nb, st) - ld = load!(nb, elt, operands(inst)[1]) - store!(nb, ld, temp) - end - if maywrite - nb = IRBuilder() - position!(nb, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(st))) - ld = load!(nb, elt, temp) - si = store!(nb, ld, operands(inst)[1]) - julia_post_cache_store(si.ref, nb.ref, reinterpret(Ptr{UInt64}, C_NULL)) - end - end - - if temp !== nothing - replace_uses!(inst, temp) - end - LLVM.API.LLVMInstructionEraseFromParent(inst) - end - end - return nothing -end - -function pre_attr!(mod::LLVM.Module) - return nothing - tofinalize = Tuple{LLVM.Function,Bool,Vector{Int64}}[] - for fn in collect(functions(mod)) - if isempty(blocks(fn)) - continue - end - if linkage(fn) != LLVM.API.LLVMInternalLinkage && - linkage(fn) != LLVM.API.LLVMPrivateLinkage - continue - end - - fty = LLVM.FunctionType(fn) - nfn = LLVM.Function(mod, "enzyme_attr_prev_" * LLVM.name(enzymefn), fty) - LLVM.IRBuilder() do builder - entry = BasicBlock(nfn, "entry") - position!(builder, entry) - cv = call!(fn, [LLVM.UndefValue(ty) for ty in parameters(fty)]) - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1), attr) - if LLVM.return_type(fty) == LLVM.VoidType() - 
ret!(builder) - else - ret!(builder, cv) - end - end - end - return nothing -end - function jl_inst_simplify!(PM::LLVM.ModulePassManager) ccall( (:LLVMAddJLInstSimplifyPass, API.libEnzyme), @@ -1403,1157 +448,8 @@ end function post_attr!(mod::LLVM.Module) end -function prop_global!(g::LLVM.GlobalVariable) - newfns = String[] - changed = false - todo = Tuple{Vector{Cuint},LLVM.Value}[] - for u in LLVM.uses(g) - u = LLVM.user(u) - push!(todo, (Cuint[], u)) - end - while length(todo) > 0 - path, var = pop!(todo) - if isa(var, LLVM.LoadInst) - B = IRBuilder() - position!(B, var) - res = LLVM.initializer(g) - for p in path - res = extract_value!(B, res, p) - end - changed = true - for u in LLVM.uses(var) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - f2 = LLVM.called_operand(u) - if isa(f2, LLVM.Function) - push!(newfns, LLVM.name(f2)) - end - end - end - replace_uses!(var, res) - eraseInst(LLVM.parent(var), var) - continue - end - if isa(var, LLVM.AddrSpaceCastInst) - for u in LLVM.uses(var) - u = LLVM.user(u) - push!(todo, (path, u)) - end - continue - end - if isa(var, LLVM.ConstantExpr) && opcode(var) == LLVM.API.LLVMAddrSpaceCast - for u in LLVM.uses(var) - u = LLVM.user(u) - push!(todo, (path, u)) - end - continue - end - if isa(var, LLVM.GetElementPtrInst) - if all(isa(v, LLVM.ConstantInt) for v in operands(var)[2:end]) - if convert(Cuint, operands(var)[2]) == 0 - for u in LLVM.uses(var) - u = LLVM.user(u) - push!( - todo, - ( - vcat( - path, - collect(( - convert(Cuint, v) for v in operands(var)[3:end] - )), - ), - u, - ), - ) - end - end - continue - end - end - end - return changed, newfns -end - -# From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 -function mayWriteToMemory(@nospecialize(inst::LLVM.Instruction); err_is_readonly::Bool = false)::Bool - # we will ignore fense here - if isa(inst, LLVM.StoreInst) - return true - end - if isa(inst, LLVM.VAArgInst) - return true - end - if isa(inst, LLVM.AtomicCmpXchgInst) - return true - end - 
if isa(inst, LLVM.AtomicRMWInst) - return true - end - if isa(inst, LLVM.CatchPadInst) - return true - end - if isa(inst, LLVM.CatchRetInst) - return true - end - if isa(inst, LLVM.CallInst) || isa(inst, LLVM.InvokeInst) || isa(inst, LLVM.CallBrInst) - idx = reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - attr = LLVM.Attribute(unsafe_load(Attrs, j)) - if kind(attr) == kind(EnumAttribute("readnone")) - return false - end - if kind(attr) == kind(EnumAttribute("readonly")) - return false - end - # Note out of spec, and only legal in context of removing unused calls - if kind(attr) == kind(StringAttribute("enzyme_error")) && err_is_readonly - return false - end - if kind(attr) == kind(StringAttribute("memory")) - if is_readonly(MemoryEffect(value(attr))) - return false - end - end - end - Libc.free(Attrs) - return true - end - # Ignoring load unordered case - return false -end - -function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) - calls = LLVM.CallInst[] - - hasUser = false - for u in LLVM.uses(fn) - un = LLVM.user(u) - - # Only permit call users - if !isa(un, LLVM.CallInst) - return false - end - un = un::LLVM.CallInst - - # Passing the fn as an argument is not permitted - for op in collect(operands(un))[1:end-1] - if op == fn - return false - end - end - - # Something with a user is not permitted - for u2 in LLVM.uses(un) - hasUser = true - break - end - push!(calls, un) - end - - done = Set{LLVM.Function}() - todo = LLVM.Function[fn] - - while length(todo) != 0 - cur = pop!(todo) - if cur in done - continue - end - push!(done, cur) - - if is_readonly(cur) - continue - end - - if LLVM.name(cur) == "julia.safepoint" - continue - end - - if 
isempty(blocks(cur)) - return false - end - - err_is_readonly = !is_noreturn(cur) - - for bb in blocks(cur) - for inst in instructions(bb) - if !mayWriteToMemory(inst; err_is_readonly) - continue - end - if isa(inst, LLVM.CallInst) - - fn2 = LLVM.called_operand(inst) - if isa(fn2, LLVM.Function) - push!(todo, fn2) - continue - end - end - return false - end - end - end - - changed = set_readonly!(fn) - - if length(calls) == 0 || hasUser - return changed - end - - for c in calls - parentf = LLVM.parent(LLVM.parent(c)) - push!(next, LLVM.name(parentf)) - LLVM.API.LLVMInstructionEraseFromParent(c) - end - push!(next, LLVM.name(fn)) - return true -end - -function propagate_returned!(mod::LLVM.Module) - globs = LLVM.GlobalVariable[] - for g in globals(mod) - if linkage(g) == LLVM.API.LLVMInternalLinkage || - linkage(g) == LLVM.API.LLVMPrivateLinkage - if !isconstant(g) - continue - end - push!(globs, g) - end - end - todo = collect(functions(mod)) - while true - next = Set{String}() - changed = false - for g in globs - tc, tn = prop_global!(g) - changed |= tc - for f in tn - push!(next, f) - end - end - tofinalize = Tuple{LLVM.Function,Bool,Vector{Int64}}[] - for fn in functions(mod) - if isempty(blocks(fn)) - continue - end - if remove_readonly_unused_calls!(fn, next) - changed = true - end - attrs = collect(function_attributes(fn)) - prevent = any( - kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for - attr in attrs - ) - # if any(kind(attr) == kind(EnumAttribute("noinline")) for attr in attrs) - # continue - # end - argn = nothing - toremove = Int64[] - for (i, arg) in enumerate(parameters(fn)) - if any( - kind(attr) == kind(EnumAttribute("returned")) for - attr in collect(parameter_attributes(fn, i)) - ) - argn = i - end - - # remove unused sret-like - if !prevent && - ( - linkage(fn) == LLVM.API.LLVMInternalLinkage || - linkage(fn) == LLVM.API.LLVMPrivateLinkage - ) && - any( - kind(attr) == kind(EnumAttribute("nocapture")) for - attr in 
collect(parameter_attributes(fn, i)) - ) - val = nothing - illegalUse = false - torem = LLVM.Instruction[] - argeltype = if LLVM.version().major >= 12 - # TODO try to get sret element type if possible - # note currently opaque pointers has this break [and we need to doa check if opaque - # and if so get inner piece] - eltype(value_type(arg)) - else - eltype(value_type(arg)) - end - for u in LLVM.uses(fn) - un = LLVM.user(u) - if !isa(un, LLVM.CallInst) - illegalUse = true - break - end - ops = collect(operands(un))[1:end-1] - bad = false - for op in ops - if op == fn - bad = true - break - end - end - if bad - illegalUse = true - break - end - if !isa(ops[i], LLVM.AllocaInst) && !isa(ops[i], LLVM.UndefValue) && !isa(ops[i], LLVM.PoisonValue) - illegalUse = true - break - end - eltype = if isa(ops[i], LLVM.AllocaInst) - LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i])) - else - LLVM.eltype(value_type(ops[i])) - end - seenfn = false - todo = LLVM.Instruction[] - if isa(ops[i], LLVM.AllocaInst) - for u2 in LLVM.uses(ops[i]) - un2 = LLVM.user(u2) - push!(todo, un2) - end - end - while length(todo) > 0 - un2 = pop!(todo) - if isa(un2, LLVM.BitCastInst) - push!(torem, un2) - for u3 in LLVM.uses(un2) - un3 = LLVM.user(u3) - push!(todo, un3) - end - continue - end - if isa(un2, LLVM.GetElementPtrInst) - push!(torem, un2) - for u3 in LLVM.uses(un2) - un3 = LLVM.user(u3) - push!(todo, un3) - end - continue - end - if !isa(un2, LLVM.CallInst) - illegalUse = true - break - end - ff = LLVM.called_operand(un2) - if !isa(ff, LLVM.Function) - illegalUse = true - break - end - if un2 == un && !seenfn - seenfn = true - continue - end - intr = LLVM.API.LLVMGetIntrinsicID(ff) - if intr == LLVM.Intrinsic("llvm.lifetime.start").id - push!(torem, un2) - continue - end - if intr == LLVM.Intrinsic("llvm.lifetime.end").id - push!(torem, un2) - continue - end - if LLVM.name(ff) != "llvm.enzyme.sret_use" - illegalUse = true - break - end - push!(torem, un2) - end - if illegalUse - break - 
end - end - if !illegalUse - for c in reverse(torem) - eraseInst(LLVM.parent(c), c) - end - B = IRBuilder() - position!(B, first(instructions(first(blocks(fn))))) - al = alloca!(B, argeltype) - if value_type(al) != value_type(arg) - al = addrspacecast!(B, al, value_type(arg)) - end - LLVM.replace_uses!(arg, al) - end - end - - # interprocedural const prop from callers of arg - if !prevent && ( - linkage(fn) == LLVM.API.LLVMInternalLinkage || - linkage(fn) == LLVM.API.LLVMPrivateLinkage - ) - val = nothing - illegalUse = false - for u in LLVM.uses(fn) - un = LLVM.user(u) - if !isa(un, LLVM.CallInst) - illegalUse = true - break - end - ops = collect(operands(un))[1:end-1] - bad = false - for op in ops - if op == fn - bad = true - break - end - end - if bad - illegalUse = true - break - end - if isa(ops[i], LLVM.UndefValue) || isa(ops[i], LLVM.PoisonValue) - continue - end - if ops[i] == arg - continue - end - if isa(ops[i], LLVM.Constant) - if val === nothing - val = ops[i] - else - if val != ops[i] - illegalUse = true - break - end - end - continue - end - illegalUse = true - break - end - if !illegalUse - if val === nothing - val = LLVM.UndefValue(value_type(arg)) - end - for u in LLVM.uses(arg) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - f2 = LLVM.called_operand(u) - if isa(f2, LLVM.Function) - push!(next, LLVM.name(f2)) - end - end - changed = true - end - LLVM.replace_uses!(arg, val) - end - end - # see if there are no users of the value (excluding recursive/return) - baduse = false - for u in LLVM.uses(arg) - u = LLVM.user(u) - if argn == i && LLVM.API.LLVMIsAReturnInst(u) != C_NULL - continue - end - if !isa(u, LLVM.CallInst) - baduse = true - break - end - if LLVM.called_operand(u) != fn - baduse = true - break - end - for (si, op) in enumerate(operands(u)) - if si == i - continue - end - if op == arg - baduse = true - break - end - end - if baduse - break - end - end - if !baduse - push!(toremove, i - 1) - end - end - illegalUse = !( - linkage(fn) == 
LLVM.API.LLVMInternalLinkage || - linkage(fn) == LLVM.API.LLVMPrivateLinkage - ) - hasAnyUse = false - for u in LLVM.uses(fn) - un = LLVM.user(u) - if !isa(un, LLVM.CallInst) - illegalUse = true - continue - end - ops = collect(operands(un))[1:end-1] - bad = false - for op in ops - if op == fn - bad = true - break - end - end - if bad - illegalUse = true - continue - end - if argn !== nothing - hasUse = false - for u in LLVM.uses(un) - hasUse = true - break - end - if hasUse - changed = true - push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) - LLVM.replace_uses!(un, ops[argn]) - end - else - for u in LLVM.uses(un) - hasAnyUse = true - break - end - end - end - #if the function return has no users whatsoever, remove it - if argn === nothing && - !hasAnyUse && - LLVM.return_type(LLVM.function_type(fn)) != LLVM.VoidType() - argn = -1 - end - if argn === nothing && length(toremove) == 0 - continue - end - if !illegalUse - push!(tofinalize, (fn, argn === nothing, toremove)) - end - end - for (fn, keepret, toremove) in tofinalize - todo = LLVM.CallInst[] - for u in LLVM.uses(fn) - un = LLVM.user(u) - push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) - end - delete_writes_into_removed_args(fn, toremove, keepret) - nm = LLVM.name(fn) - #try - nfn = LLVM.Function( - API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove), - ) - for u in LLVM.uses(fn) - un = LLVM.user(u) - push!(todo, un) - end - for un in todo - md = metadata(un) - if !keepret && haskey(md, LLVM.MD_range) - delete!(md, LLVM.MD_range) - end - API.EnzymeSetCalledFunction(un, nfn, toremove) - end - eraseInst(mod, fn) - changed = true - # catch e - # break - #end - end - if !changed - break - else - todo = LLVM.Function[] - for name in next - fn = functions(mod)[name] - if linkage(fn) == LLVM.API.LLVMInternalLinkage || - linkage(fn) == LLVM.API.LLVMPrivateLinkage - has_user = false - for u in LLVM.uses(fn) - has_user = true - break - end - if !has_user - LLVM.API.LLVMDeleteFunction(fn) - end - 
end - push!(todo, fn) - end - end - end -end - -function delete_writes_into_removed_args(fn::LLVM.Function, toremove::Vector{Int64}, keepret::Bool) - args = collect(parameters(fn)) - for tr in toremove - tr = tr + 1 - todorep = Tuple{LLVM.Instruction, LLVM.Value}[] - for opv in LLVM.uses(args[tr]) - u = LLVM.user(opv) - push!(todorep, (u, args[tr])) - end - toerase = LLVM.Instruction[] - while length(todorep) != 0 - cur, cval = pop!(todorep) - if isa(cur, LLVM.StoreInst) - if operands(cur)[2] == cval - LLVM.API.LLVMInstructionEraseFromParent(nphi) - continue - end - end - if isa(cur, LLVM.GetElementPtrInst) || - isa(cur, LLVM.BitCastInst) || - isa(cur, LLVM.AddrSpaceCastInst) - for opv in LLVM.uses(cur) - u = LLVM.user(opv) - push!(todorep, (u, cur)) - end - continue - end - if isa(cur, LLVM.CallInst) - cf = LLVM.called_operand(cur) - if cf == fn - baduse = false - for (i, v) in enumerate(operands(cur)) - if i-1 in toremove - continue - end - if v == cval - baduse = true - end - end - if !baduse - continue - end - end - end - if !keepret && LLVM.API.LLVMIsAReturnInst(cur) != C_NULL - LLVM.API.LLVMSetOperand(cur, 0, LLVM.UndefValue(value_type(cval))) - continue - end - throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))")) - end - end -end - -function detect_writeonly!(mod::LLVM.Module) - for f in functions(mod) - if isempty(LLVM.blocks(f)) - continue - end - for (i, a) in enumerate(parameters(f)) - if isa(value_type(a), LLVM.PointerType) - todo = Tuple{LLVM.Value,LLVM.Instruction}[] - for u in LLVM.uses(a) - push!(todo, (a, LLVM.user(u))) - end - seen = Set{Tuple{LLVM.Value,LLVM.Instruction}}() - mayread = false - maywrite = false - while length(todo) > 0 - cur = pop!(todo) - if in(cur, seen) - continue - end - push!(seen, cur) - curv, curi = cur - - if isa(curi, LLVM.StoreInst) - if operands(curi)[1] != curv - maywrite = true - continue - end - end - - if isa(curi, LLVM.LoadInst) - mayread = true - continue - end 
- - if isa(curi, LLVM.GetElementPtrInst) || - isa(curi, LLVM.BitCastInst) || - isa(curi, LLVM.AddrSpaceCastInst) - for u in LLVM.uses(curi) - push!(todo, (curi, LLVM.user(u))) - end - continue - end - mayread = true - maywrite = true - end - if any( - map( - k -> kind(k) == kind(EnumAttribute("readnone")), - collect(parameter_attributes(f, i)), - ), - ) - mayread = false - maywrite = false - end - if any( - map( - k -> kind(k) == kind(EnumAttribute("readonly")), - collect(parameter_attributes(f, i)), - ), - ) - maywrite = false - end - if any( - map( - k -> kind(k) == kind(EnumAttribute("writeonly")), - collect(parameter_attributes(f, i)), - ), - ) - mayread = false - end - - LLVM.API.LLVMRemoveEnumAttributeAtIndex( - f, - LLVM.API.LLVMAttributeIndex(i), - kind(EnumAttribute("readnone")), - ) - LLVM.API.LLVMRemoveEnumAttributeAtIndex( - f, - LLVM.API.LLVMAttributeIndex(i), - kind(EnumAttribute("readonly")), - ) - LLVM.API.LLVMRemoveEnumAttributeAtIndex( - f, - LLVM.API.LLVMAttributeIndex(i), - kind(EnumAttribute("writeonly")), - ) - - if !mayread && !maywrite - push!(parameter_attributes(f, i), LLVM.EnumAttribute("readnone", 0)) - elseif !mayread - push!(parameter_attributes(f, i), LLVM.EnumAttribute("writeonly", 0)) - elseif !maywrite - push!(parameter_attributes(f, i), LLVM.EnumAttribute("readonly", 0)) - end - - end - end - end - return nothing -end - -function validate_return_roots!(mod::LLVM.Module) - for f in functions(mod) - srets = [] - enzyme_srets = Int[] - enzyme_srets_v = Int[] - rroots = Int[] - rroots_v = Int[] - sretkind = kind(if LLVM.version().major >= 12 - TypeAttribute("sret", LLVM.Int32Type()) - else - EnumAttribute("sret") - end) - for (i, a) in enumerate(parameters(f)) - for attr in collect(parameter_attributes(f, i)) - if isa(attr, StringAttribute) - if kind(attr) == "enzymejl_returnRoots" - push!(rroots, i) - end - if kind(attr) == "enzymejl_returnRoots_v" - push!(rroots_v, i) - end - if kind(attr) == "enzyme_sret" - push!(enzyme_srets, i) - 
end - if kind(attr) == "enzyme_sret_v" - push!(enzyme_srets, i) - end - end - if kind(attr) == sretkind - push!(srets, (i, attr)) - end - end - end - if length(enzyme_srets) >= 1 && length(srets) == 0 - @assert enzyme_srets[1] == 1 - VT = LLVM.VoidType() - if length(enzyme_srets) == 1 && - LLVM.return_type(LLVM.function_type(f)) == VT && - length(enzyme_srets_v) == 0 - # Upgrading to sret requires writeonly - if !any( - kind(attr) == kind(EnumAttribute("writeonly")) for - attr in collect(parameter_attributes(f, 1)) - ) - msg = sprint() do io::IO - println(io, "Enzyme internal error (not writeonly sret)") - println(io, string(f)) - println( - io, - "collect(parameter_attributes(f, 1))=", - collect(parameter_attributes(f, 1)), - ) - end - throw(AssertionError(msg)) - end - - alty = nothing - for u in LLVM.uses(f) - u = LLVM.user(u) - @assert isa(u, LLVM.CallInst) - @assert LLVM.called_operand(u) == f - alop = operands(u)[1] - if !isa(alop, LLVM.AllocaInst) - msg = sprint() do io::IO - println(io, "Enzyme internal error (!isa(alop, LLVM.AllocaInst))") - println(io, "alop=", alop) - println(io, "u=", u) - println(io, "f=", string(f)) - end - throw(AssertionError(msg)) - - end - @assert isa(alop, LLVM.AllocaInst) - nty = API.EnzymeAllocaType(alop) - if alty === nothing - alty = nty - else - @assert alty == nty - end - attr = if LLVM.version().major >= 12 - TypeAttribute("sret", alty) - else - EnumAttribute("sret") - end - LLVM.API.LLVMAddCallSiteAttribute( - u, - LLVM.API.LLVMAttributeIndex(1), - attr, - ) - LLVM.API.LLVMRemoveCallSiteStringAttribute( - u, - LLVM.API.LLVMAttributeIndex(1), - "enzyme_sret", - length("enzyme_sret"), - ) - end - @assert alty !== nothing - attr = if LLVM.version().major >= 12 - TypeAttribute("sret", alty) - else - EnumAttribute("sret") - end - - push!(parameter_attributes(f, 1), attr) - delete!(parameter_attributes(f, 1), StringAttribute("enzyme_sret")) - srets = [(1, attr)] - enzyme_srets = Int[] - else - - enzyme_srets2 = Int[] - for idx 
in enzyme_srets - alty = nothing - bad = false - for u in LLVM.uses(f) - u = LLVM.user(u) - @assert isa(u, LLVM.CallInst) - @assert LLVM.called_operand(u) == f - alop = operands(u)[1] - @assert isa(alop, LLVM.AllocaInst) - nty = API.EnzymeAllocaType(alop) - if any_jltypes(nty) - bad = true - end - LLVM.API.LLVMRemoveCallSiteStringAttribute( - u, - LLVM.API.LLVMAttributeIndex(idx), - "enzyme_sret", - length("enzyme_sret"), - ) - end - if !bad - delete!( - parameter_attributes(f, idx), - StringAttribute("enzyme_sret"), - ) - else - push!(enzyme_srets2, idx) - end - end - enzyme_srets = enzyme_srets2 - - if length(enzyme_srets) != 0 - msg = sprint() do io::IO - println(io, "Enzyme internal error (length(enzyme_srets) != 0)") - println(io, "f=", string(f)) - println(io, "enzyme_srets=", enzyme_srets) - println(io, "enzyme_srets_v=", enzyme_srets_v) - println(io, "srets=", srets) - println(io, "rroots=", rroots) - println(io, "rroots_v=", rroots_v) - end - throw(AssertionError(msg)) - end - end - end - @assert length(enzyme_srets_v) == 0 - for (i, attr) in srets - @assert i == 1 - end - for i in rroots - @assert length(srets) != 0 - @assert i == 2 - end - # illegal - for i in rroots_v - @assert false - end - end -end - -function checkNoAssumeFalse(mod::LLVM.Module, shouldshow::Bool = false) - for f in functions(mod) - for bb in blocks(f), inst in instructions(bb) - if !isa(inst, LLVM.CallInst) - continue - end - intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(inst)) - if intr != LLVM.Intrinsic("llvm.assume").id - continue - end - op = operands(inst)[1] - if isa(op, LLVM.ConstantInt) - op2 = convert(Bool, op) - if !op2 - msg = sprint() do io - println(io, "Enzyme Internal Error: non-constant assume condition") - println(io, "mod=", string(mod)) - println(io, "f=", string(f)) - println(io, "bb=", string(bb)) - println(io, "op2=", string(op2)) - end - throw(AssertionError(msg)) - end - end - if isa(op, LLVM.ICmpInst) - if predicate_int(op) == LLVM.API.LLVMIntNE && - 
operands(op)[1] == operands(op)[2] - msg = sprint() do io - println(io, "Enzyme Internal Error: non-icmp assume condition") - println(io, "mod=", string(mod)) - println(io, "f=", string(f)) - println(io, "bb=", string(bb)) - println(io, "op=", string(op)) - end - throw(AssertionError(msg)) - end - end - end - end -end - cse!(pm) = LLVM.API.LLVMAddEarlyCSEPass(pm) -function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine) - # We need to run globalopt first. This is because remove dead args will otherwise - # take internal functions and replace their args with undef. Then on LLVM up to - # and including 12 (but fixed 13+), Attributor will incorrectly change functions that - # call code with undef to become unreachable, even when there exist other valid - # callsites. See: https://godbolt.org/z/9Y3Gv6q5M - ModulePassManager() do pm - global_dce!(pm) - LLVM.run!(pm, mod) - end - # Prevent dead-arg-elimination of functions which we may require args for in the derivative - funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg = true) - if LLVM.version().major <= 15 - func, _ = get_function!( - mod, - "llvm.enzymefakeuse", - funcT, - LLVM.Attribute[EnumAttribute("readnone"), EnumAttribute("nofree")], - ) - rfunc, _ = get_function!( - mod, - "llvm.enzymefakeread", - funcT, - LLVM.Attribute[ - EnumAttribute("readonly"), - EnumAttribute("nofree"), - EnumAttribute("argmemonly"), - ], - ) - sfunc, _ = get_function!( - mod, - "llvm.enzyme.sret_use", - funcT, - LLVM.Attribute[ - EnumAttribute("readonly"), - EnumAttribute("nofree"), - EnumAttribute("argmemonly"), - ], - ) - else - func, _ = get_function!( - mod, - "llvm.enzymefakeuse", - funcT, - LLVM.Attribute[EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")], - ) - rfunc, _ = get_function!( - mod, - "llvm.enzymefakeread", - funcT, - LLVM.Attribute[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], - ) - sfunc, _ = get_function!( - mod, - "llvm.enzyme.sret_use", - 
funcT, - LLVM.Attribute[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], - ) - end - - for fn in functions(mod) - if isempty(blocks(fn)) - continue - end - # Ensure that interprocedural optimizations do not delete the use of returnRoots (or shadows) - # if inactive sret, this will only occur on 2. If active sret, inactive retRoot, can on 3, and - # active both can occur on 4. If the original sret is removed (at index 1) we no longer need - # to preserve this. - for idx in (2, 3, 4) - if length(collect(parameters(fn))) >= idx && any( - ( - kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || - kind(attr) == kind(StringAttribute("enzymejl_returnRoots_v")) - ) for attr in collect(parameter_attributes(fn, idx)) - ) - for u in LLVM.uses(fn) - u = LLVM.user(u) - @assert isa(u, LLVM.CallInst) - B = IRBuilder() - nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) - position!(B, nextInst) - inp = operands(u)[idx] - cl = call!(B, funcT, rfunc, LLVM.Value[inp]) - if isa(value_type(inp), LLVM.PointerType) - LLVM.API.LLVMAddCallSiteAttribute( - cl, - LLVM.API.LLVMAttributeIndex(1), - EnumAttribute("nocapture"), - ) - end - end - end - end - sretkind = kind(if LLVM.version().major >= 12 - TypeAttribute("sret", LLVM.Int32Type()) - else - EnumAttribute("sret") - end) - for idx in (1, 2) - if length(collect(parameters(fn))) < idx - continue - end - attrs = collect(parameter_attributes(fn, idx)) - if any( - ( - kind(attr) == sretkind || - kind(attr) == kind(StringAttribute("enzyme_sret")) || - kind(attr) == kind(StringAttribute("enzyme_sret_v")) - ) for attr in attrs - ) && any_jltypes(sret_ty(fn, idx)) - for u in LLVM.uses(fn) - u = LLVM.user(u) - if isa(u, LLVM.ConstantExpr) - u = LLVM.user(only(LLVM.uses(u))) - end - if !isa(u, LLVM.CallInst) - continue - end - @assert isa(u, LLVM.CallInst) - B = IRBuilder() - nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) - position!(B, nextInst) - inp = operands(u)[idx] - cl 
= call!(B, funcT, sfunc, LLVM.Value[inp]) - if isa(value_type(inp), LLVM.PointerType) - LLVM.API.LLVMAddCallSiteAttribute( - cl, - LLVM.API.LLVMAttributeIndex(1), - EnumAttribute("nocapture"), - ) - end - end - end - end - attrs = collect(function_attributes(fn)) - prevent = any( - kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs - ) - # && any(kind(attr) == kind(StringAttribute("enzyme_math")) for attr in attrs) - if prevent - B = IRBuilder() - position!(B, first(instructions(first(blocks(fn))))) - call!(B, funcT, func, LLVM.Value[p for p in parameters(fn)]) - end - end - propagate_returned!(mod) - ModulePassManager() do pm - instruction_combining!(pm) - jl_inst_simplify!(pm) - alloc_opt_tm!(pm, tm) - scalar_repl_aggregates_ssa!(pm) # SSA variant? - cse!(pm) - LLVM.run!(pm, mod) - end - propagate_returned!(mod) - pre_attr!(mod) - if RunAttributor[] - if LLVM.version().major >= 13 - ModulePassManager() do pm - API.EnzymeAddAttributorLegacyPass(pm) - LLVM.run!(pm, mod) - end - end - end - propagate_returned!(mod) - ModulePassManager() do pm - instruction_combining!(pm) - jl_inst_simplify!(pm) - alloc_opt_tm!(pm, tm) - scalar_repl_aggregates_ssa!(pm) # SSA variant? 
- if RunAttributor[] - if LLVM.version().major >= 13 - API.EnzymeAddAttributorLegacyPass(pm) - end - end - cse!(pm) - LLVM.run!(pm, mod) - end - post_attr!(mod) - propagate_returned!(mod) - - for u in LLVM.uses(rfunc) - u = LLVM.user(u) - eraseInst(LLVM.parent(u), u) - end - eraseInst(mod, rfunc) - for u in LLVM.uses(sfunc) - u = LLVM.user(u) - eraseInst(LLVM.parent(u), u) - end - eraseInst(mod, sfunc) - for fn in functions(mod) - for b in blocks(fn) - inst = first(LLVM.instructions(b)) - if isa(inst, LLVM.CallInst) - fn = LLVM.called_operand(inst) - if fn == func - eraseInst(b, inst) - end - end - end - end - eraseInst(mod, func) -end - function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine) addr13NoAlias(mod) # everying except unroll, slpvec, loop-vec diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index e58e574dd2..7f004ea379 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -109,7 +109,7 @@ function set_writing(mri::ModRefInfo) return mri | MRI_Mod end -function set_readonly(effect::MemoryEffect) +function set_readonly(effect::MemoryEffect)::MemoryEffect data = UInt32(0) for loc in (ArgMem, InaccessibleMem, Other) data = UInt32(set_readonly(getModRef(effect, loc))) << getLocationPos(loc) @@ -117,15 +117,15 @@ function set_readonly(effect::MemoryEffect) return MemoryEffect(data) end -function is_readonly(mri::ModRefInfo) +function is_readonly(mri::ModRefInfo)::Bool return mri == MRI_NoModRef || mri == MRI_Ref end -function is_readnone(mri::ModRefInfo) +function is_readnone(mri::ModRefInfo)::Bool return mri == MRI_NoModRef end -function is_writeonly(mri::ModRefInfo) +function is_writeonly(mri::ModRefInfo)::Bool return mri == MRI_NoModRef || mri == MRI_Mod end @@ -137,7 +137,7 @@ for n in (:is_readonly, :is_readnone, :is_writeonly) end end -function is_noreturn(f::LLVM.Function) +Base.@assume_effects :removable :foldable :nothrow function is_noreturn(f::LLVM.Function)::Bool for attr in collect(function_attributes(f)) if kind(attr) 
== kind(EnumAttribute("noreturn")) return true @@ -146,7 +146,7 @@ function is_noreturn(f::LLVM.Function) return false end -function is_readonly(f::LLVM.Function) +Base.@assume_effects :removable :foldable :nothrow function is_readonly(f::LLVM.Function)::Bool intr = LLVM.API.LLVMGetIntrinsicID(f) if intr == LLVM.Intrinsic("llvm.lifetime.start").id return true @@ -179,7 +179,7 @@ function is_readonly(f::LLVM.Function) return false end -function is_readnone(f::LLVM.Function) +Base.@assume_effects :removable :foldable :nothrow function is_readnone(f::LLVM.Function)::Bool intr = LLVM.API.LLVMGetIntrinsicID(f) if intr == LLVM.Intrinsic("llvm.lifetime.start").id return true @@ -209,7 +209,7 @@ function is_readnone(f::LLVM.Function) return false end -function is_writeonly(f::LLVM.Function) +Base.@assume_effects :removable :foldable :nothrow function is_writeonly(f::LLVM.Function)::Bool intr = LLVM.API.LLVMGetIntrinsicID(f) if intr == LLVM.Intrinsic("llvm.lifetime.start").id return true @@ -329,14 +329,13 @@ function get_pgcstack(func::LLVM.Function) end function reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing, LLVM.IRBuilder}) = nothing) - for (i, v) in enumerate(parameters(func)) - if any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(func, i)), - ), - ) - return v + for i in 1:length(LLVM.parameters(func)) + for attr in collect(LLVM.parameter_attributes(func, i)) + if attr isa LLVM.EnumAttribute + if kind(attr) == swiftself_kind + return parameters(func)[i] + end + end end end @@ -361,6 +360,46 @@ function reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing end end +@inline enum_attr_kind(kind::String) = LLVM.API.LLVMGetEnumAttributeKindForName(kind, Csize_t(length(kind))) + +const swiftself_kind = enum_attr_kind("swiftself") + +Base.@assume_effects :removable :foldable :nothrow function has_swiftself(fn::LLVM.Function)::Bool + for i in 1:length(LLVM.parameters(fn)) + for attr in 
collect(LLVM.parameter_attributes(fn, i)) + if attr isa LLVM.EnumAttribute + if kind(attr) == swiftself_kind + return true + end + end + end + end + return false +end +Base.@assume_effects :removable :foldable :nothrow function has_fn_attr(fn::LLVM.Function, attr::LLVM.EnumAttribute)::Bool + ekind = LLVM.kind(attr) + for attr in collect(function_attributes(fn)) + if attr isa LLVM.EnumAttribute + if kind(attr) == ekind + return true + end + end + end + return false +end + +Base.@assume_effects :removable :foldable :nothrow function has_fn_attr(fn::LLVM.Function, attr::LLVM.StringAttribute)::Bool + ekind = LLVM.kind(attr) + for attr in collect(function_attributes(fn)) + if attr isa LLVM.StringAttribute + if kind(attr) == ekind + return true + end + end + end + return false +end + function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) @static if isdefined(LLVM, Symbol("erase!")) LLVM.erase!(inst) @@ -404,17 +443,17 @@ function unique_gcmarker!(func::LLVM.Function) end @inline AnonymousStruct(::Type{U}) where {U<:Tuple} = - NamedTuple{ntuple(i -> Symbol(i), Val(length(U.parameters))),U} + NamedTuple{ntuple(Symbol, Val(length(U.parameters))),U} # recursively compute the eltype type indexed by idx[0], idx[1], ... 
-function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint}) - ty = LLVM.value_type(val) +Base.@assume_effects :removable :foldable :nothrow function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint})::LLVM.LLVMType + ty = LLVM.value_type(val)::LLVM.LLVMType for i in idxs if isa(ty, LLVM.ArrayType) - ty = eltype(ty) + ty = eltype(ty)::LLVM.LLVMType else @assert isa(ty, LLVM.StructType) - ty = elements(ty)[i+1] + ty = elements(ty)[i+1]::LLVM.LLVMType end end return ty diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 3e833324b6..e90f7d0712 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -15,7 +15,7 @@ end function get_blas_symbols() symbols = BLAS.get_config().exported_symbols if BLAS.USE_BLAS64 - return map(n -> n * "64_", symbols) + return map(Base.Fix2(*, "64_"), symbols) end return symbols end diff --git a/src/errors.jl b/src/errors.jl index a1a81580e6..b48c34b54d 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -79,6 +79,85 @@ function Base.showerror(io::IO, ece::EnzymeInternalError) end end +struct EnzymeRuntimeException <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeRuntimeException) + print(io, "Enzyme execution failed.\n") + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeMutabilityException <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeMutabilityException) + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeRuntimeActivityError <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) + println(io, "Constant memory is stored (or returned) to a differentiable variable.") + println( + io, + "As a result, Enzyme cannot provably ensure correctness and throws this error.", + ) + println( + io, + "This might be due to the use of a constant variable as temporary storage for active memory 
(https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).", + ) + println( + io, + "If Enzyme should be able to prove this use non-differentable, open an issue!", + ) + println(io, "To work around this issue, either:") + println( + io, + " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or", + ) + println( + io, + " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.", + ) + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeNoTypeError <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeNoTypeError) + print(io, "Enzyme cannot deduce type\n") + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeNoShadowError <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeNoShadowError) + print(io, "Enzyme could not find shadow for value\n") + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +struct EnzymeNoDerivativeError <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? 
LLVM.parent(val) : val parent_scope(val::LLVM.Module, depth = 0) = val parent_scope(@nospecialize(val::LLVM.Value), depth = 0) = parent_scope(LLVM.parent(val), depth + 1) diff --git a/src/compiler/passes.jl b/src/llvm/passes.jl similarity index 99% rename from src/compiler/passes.jl rename to src/llvm/passes.jl index 403b2bfa04..b9d8ca9385 100644 --- a/src/compiler/passes.jl +++ b/src/llvm/passes.jl @@ -1,3 +1,4 @@ + function reinsert_gcmarker_pass!(fn::LLVM.Function) reinsert_gcmarker!(fn) unique_gcmarker!(fn) diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl new file mode 100644 index 0000000000..b1a5aaafbc --- /dev/null +++ b/src/llvm/transforms.jl @@ -0,0 +1,2181 @@ + +function force_recompute!(mod::LLVM.Module) + for f in functions(mod), bb in blocks(f) + iter = LLVM.API.LLVMGetFirstInstruction(bb) + while iter != C_NULL + inst = LLVM.Instruction(iter) + iter = LLVM.API.LLVMGetNextInstruction(iter) + if isa(inst, LLVM.LoadInst) + has_loaded = false + for u in LLVM.uses(inst) + v = LLVM.user(u) + if isa(v, LLVM.CallInst) + cf = LLVM.called_operand(v) + if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" && operands(v)[2] == inst + has_loaded = true + break + end + end + if isa(v, LLVM.BitCastInst) + for u2 in LLVM.uses(v) + v2 = LLVM.user(u2) + if isa(v2, LLVM.CallInst) + cf = LLVM.called_operand(v2) + if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" && operands(v2)[2] == v + has_loaded = true + break + end + end + end + end + end + if has_loaded + metadata(inst)["enzyme_nocache"] = MDNode(LLVM.Metadata[]) + end + end + if isa(inst, LLVM.CallInst) + cf = LLVM.called_operand(inst) + if isa(cf, LLVM.Function) + if LLVM.name(cf) == "llvm.julia.gc_preserve_begin" + has_use = false + for u2 in LLVM.uses(inst) + has_use = true + break + end + if !has_use + eraseInst(bb, inst) + end + end + end + end + end + end +end + +function permit_inlining!(f::LLVM.Function) + for bb in blocks(f), inst in instructions(bb) + # remove illegal 
invariant.load and jtbaa_const invariants + if isa(inst, LLVM.LoadInst) + md = metadata(inst) + if haskey(md, LLVM.MD_tbaa) + modified = LLVM.Metadata( + ccall( + (:EnzymeMakeNonConstTBAA, API.libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef,), + md[LLVM.MD_tbaa], + ), + ) + setindex!(md, modified, LLVM.MD_tbaa) + end + if haskey(md, LLVM.MD_invariant_load) + delete!(md, LLVM.MD_invariant_load) + end + end + end +end + +function addNA(@nospecialize(inst::LLVM.Instruction), @nospecialize(node::LLVM.Metadata), MD::LLVM.MDKind) + md = metadata(inst) + next = nothing + if haskey(md, MD) + next = LLVM.MDNode(Metadata[node, operands(md[MD])...]) + else + next = LLVM.MDNode(Metadata[node]) + end + setindex!(md, next, MD) +end + +function addr13NoAlias(mod::LLVM.Module) + ctx = LLVM.context(mod) + dom = API.EnzymeAnonymousAliasScopeDomain("addr13", ctx) + scope = API.EnzymeAnonymousAliasScope(dom, "na_addr13") + aliasscope = noalias = scope + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.StoreInst) + addNA(inst, noalias, LLVM.MD_noalias) + elseif isa(inst, LLVM.CallInst) + fn = LLVM.called_operand(inst) + if isa(fn, LLVM.Function) + name = LLVM.name(fn) + if startswith(name, "llvm.memcpy") || startswith(name, "llvm.memmove") + addNA(inst, noalias, LLVM.MD_noalias) + end + end + elseif isa(inst, LLVM.LoadInst) + ty = value_type(inst) + if isa(ty, LLVM.PointerType) + if addrspace(ty) == 13 + addNA(inst, aliasscope, LLVM.MD_alias_scope) + end + end + end + end +end + +## given code like +# % a = alloca +# ... 
+# memref(cast(%a), %b, constant size == sizeof(a)) +# +# turn this into load/store, as this is more +# amenable to caching analysis infrastructure +function memcpy_alloca_to_loadstore(mod::LLVM.Module) + dl = datalayout(mod) + for f in functions(mod) + if length(blocks(f)) != 0 + bb = first(blocks(f)) + todel = Set{LLVM.Instruction}() + for alloca in instructions(bb) + if !isa(alloca, LLVM.AllocaInst) + continue + end + todo = Tuple{LLVM.Instruction,LLVM.Value}[(alloca, alloca)] + copy = nothing + legal = true + elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca)) + lifetimestarts = LLVM.Instruction[] + while length(todo) > 0 + cur, prev = pop!(todo) + if isa(cur, LLVM.AllocaInst) || + isa(cur, LLVM.AddrSpaceCastInst) || + isa(cur, LLVM.BitCastInst) + for u in LLVM.uses(cur) + u = LLVM.user(u) + push!(todo, (u, cur)) + end + continue + end + if isa(cur, LLVM.CallInst) && + isa(LLVM.called_operand(cur), LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur)) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + push!(lifetimestarts, cur) + continue + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + continue + end + if intr == LLVM.Intrinsic("llvm.memcpy").id + sz = operands(cur)[3] + if operands(cur)[1] == prev && + isa(sz, LLVM.ConstantInt) && + convert(Int, sz) == sizeof(dl, elty) + if copy === nothing || copy == cur + copy = cur + continue + end + end + end + end + + # read only insts of arg, don't matter + if isa(cur, LLVM.LoadInst) + continue + end + if isa(cur, LLVM.CallInst) && + isa(LLVM.called_operand(cur), LLVM.Function) + legalc = true + for (i, ci) in enumerate(operands(cur)[1:end-1]) + if ci == prev + nocapture = false + readonly = false + for a in collect( + parameter_attributes(LLVM.called_operand(cur), i), + ) + if kind(a) == kind(EnumAttribute("readonly")) + readonly = true + end + if kind(a) == kind(EnumAttribute("readnone")) + readonly = true + end + if kind(a) == kind(EnumAttribute("nocapture")) + 
nocapture = true + end + end + if !nocapture || !readonly + legalc = false + break + end + end + end + if legalc + continue + end + end + + legal = false + break + end + + if legal && copy !== nothing + B = LLVM.IRBuilder() + position!(B, copy) + dst = operands(copy)[1] + src = operands(copy)[2] + dst0 = bitcast!( + B, + dst, + LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst))), + ) + + dst = + bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) + src = + bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) + + src = load!(B, elty, src) + FT = LLVM.FunctionType( + LLVM.VoidType(), + [LLVM.IntType(64), value_type(dst0)], + ) + lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT) + call!( + B, + FT, + lifetimestart, + LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0], + ) + store!(B, src, dst) + push!(todel, copy) + end + for lt in lifetimestarts + push!(todel, lt) + end + end + for inst in todel + eraseInst(LLVM.parent(inst), inst) + end + end + end +end + +# If there is a phi node of a decayed value, Enzyme may need to cache it +# Here we force all decayed pointer phis to first addrspace from 10 +function nodecayed_phis!(mod::LLVM.Module) + # Simple handler to fix addrspace 11 + #complex handler for addrspace 13, which itself comes from a load of an + # addrspace 10 + ctx = LLVM.context(mod) + for f in functions(mod) + + guaranteedInactive = false + + for attr in collect(function_attributes(f)) + if !isa(attr, LLVM.StringAttribute) + continue + end + if kind(attr) == "enzyme_inactive" + guaranteedInactive = true + break + end + end + + if guaranteedInactive + continue + end + + + entry_ft = LLVM.function_type(f) + + RT = LLVM.return_type(entry_ft) + inactiveRet = RT == LLVM.VoidType() + + for attr in collect(return_attributes(f)) + if !isa(attr, LLVM.StringAttribute) + continue + end + if kind(attr) == "enzyme_inactive" + inactiveRet = true + break + end + end + + if inactiveRet + for idx in 
length(collect(parameters(f))) + inactiveParm = false + for attr in collect(parameter_attributes(f, idx)) + if !isa(attr, LLVM.StringAttribute) + continue + end + if kind(attr) == "enzyme_inactive" + inactiveParm = true + break + end + end + if !inactiveParm + inactiveRet = false + break + end + end + if inactiveRet + continue + end + end + + offty = LLVM.IntType(8 * sizeof(Int)) + i8 = LLVM.IntType(8) + + for addr in (11, 13) + + nextvs = Dict{LLVM.PHIInst,LLVM.PHIInst}() + mtodo = Vector{LLVM.PHIInst}[] + goffsets = Dict{LLVM.PHIInst,LLVM.PHIInst}() + nonphis = LLVM.Instruction[] + anyV = false + for bb in blocks(f) + todo = LLVM.PHIInst[] + nonphi = nothing + for inst in instructions(bb) + if !isa(inst, LLVM.PHIInst) + nonphi = inst + break + end + ty = value_type(inst) + if !isa(ty, LLVM.PointerType) + continue + end + if addrspace(ty) != addr + continue + end + if addr == 11 + all_args = true + addrtodo = Value[inst] + seen = Set{LLVM.Value}() + + while length(addrtodo) != 0 + v = pop!(addrtodo) + base, _ = get_base_and_offset(v; offsetAllowed=false) + if in(base, seen) + continue + end + push!(seen, base) + if isa(base, LLVM.Argument) && addrspace(value_type(base)) == 11 + continue + end + if isa(base, LLVM.PHIInst) + for (v, _) in LLVM.incoming(base) + push!(addrtodo, v) + end + continue + end + all_args = false + break + end + if all_args + continue + end + end + + push!(todo, inst) + nb = IRBuilder() + position!(nb, inst) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) + end + nphi = phi!( + nb, + LLVM.PointerType(el_ty, 10), + "nodecayed." * LLVM.name(inst), + ) + nextvs[inst] = nphi + anyV = true + + goffsets[inst] = phi!(nb, offty, "nodecayedoff." 
* LLVM.name(inst)) + end + push!(mtodo, todo) + push!(nonphis, nonphi) + end + for (bb, todo, nonphi) in zip(blocks(f), mtodo, nonphis) + + for inst in todo + ty = value_type(inst) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) + end + nvs = Tuple{LLVM.Value,LLVM.BasicBlock}[] + offsets = Tuple{LLVM.Value,LLVM.BasicBlock}[] + for (v, pb) in LLVM.incoming(inst) + done = false + for ((nv, pb0), (offset, pb1)) in zip(nvs, offsets) + if pb0 == pb + push!(nvs, (nv, pb)) + push!(offsets, (offset, pb)) + done = true + break + end + end + if done + continue + end + b = IRBuilder() + position!(b, terminator(pb)) + + + v0 = v + @inline function getparent(@nospecialize(v::LLVM.Value), @nospecialize(offset::LLVM.Value), hasload::Bool) + if addr == 11 && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && hasload && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && !hasload + if isa(v, LLVM.LoadInst) + v2, o2, hl2 = getparent(operands(v)[1], LLVM.ConstantInt(offty, 0), true) + rhs = LLVM.ConstantInt(offty, 0) + if o2 != rhs + msg = sprint() do io::IO + println( + io, + "Enzyme internal error addr13 load doesn't keep offset 0", + ) + println(io, "v=", string(v)) + println(io, "v2=", string(v2)) + println(io, "o2=", string(o2)) + println(io, "hl2=", string(hl2)) + println(io, "offty=", string(offty)) + println(io, "rhs=", string(rhs)) + end + throw(AssertionError(msg)) + end + return v2, offset, true + end + if isa(v, LLVM.CallInst) + cf = LLVM.called_operand(v) + if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" + ld = operands(v)[2] + while isa(ld, LLVM.BitCastInst) || isa(ld, LLVM.AddrSpaceCastInst) + ld = operands(ld)[1] + end + if isa(ld, LLVM.LoadInst) + v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true) + rhs = LLVM.ConstantInt(offty, sizeof(Int)) + + base_2, off_2 = get_base_and_offset(v2) + base_1, off_1 = 
get_base_and_offset(operands(v)[1]) + + if o2 == rhs && base_1 == base_2 && off_1 == off_2 + return operands(v)[1], offset, true + end + + pty = TypeTree(API.DT_Pointer, LLVM.context(ld)) + only!(pty, -1) + rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty) + metadata(rhs)["enzyme_type"] = to_md(pty, ctx) + lhs = ptrtoint!(b, operands(v)[2], offty) + metadata(rhs)["enzyme_type"] = to_md(pty, ctx) + off2 = nuwsub!(b, lhs, rhs) + ity = TypeTree(API.DT_Integer, LLVM.context(ld)) + only!(ity, -1) + metadata(off2)["enzyme_type"] = to_md(ity, ctx) + add = nuwadd!(b, offset, off2) + metadata(add)["enzyme_type"] = to_md(ity, ctx) + return operands(v)[1], add, true + end + end + end + end + + if addr == 13 && isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 0 + if addr == 13 && isa(v, LLVM.ConstantExpr) + v2 = const_addrspacecast( + operands(v)[1], + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end + end + end + + if isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 10 + return v2, offset, hasload + end + if addrspace(value_type(v2)) == 0 + if addr == 11 + v2 = const_addrspacecast( + v2, + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end + if LLVM.isnull(v2) + v2 = const_addrspacecast( + v2, + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end + if opcode(v) == LLVM.API.LLVMBitCast + preop = operands(v)[1] + while isa(preop, LLVM.ConstantExpr) && opcode(preop) == LLVM.API.LLVMBitCast + preop = operands(preop)[1] + end + v2, offset, skipload = + getparent(preop, offset, hasload) + v2 = const_bitcast( + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + + if 
opcode(v) == LLVM.API.LLVMGetElementPtr + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + offset = const_add( + offset, + API.EnzymeComputeByteOffsetOfGEP(b, v, offty), + ) + v2 = const_bitcast( + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + + end + + if isa(v, LLVM.AddrSpaceCastInst) + if addrspace(value_type(operands(v)[1])) == 0 + v2 = addrspacecast!( + b, + operands(v)[1], + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + nv, noffset, nhasload = + getparent(operands(v)[1], offset, hasload) + if eltype(value_type(nv)) != eltype(value_type(v)) + nv = bitcast!( + b, + nv, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(nv)), + ), + ) + end + return nv, noffset, nhasload + end + + if isa(v, LLVM.BitCastInst) + preop = operands(v)[1] + while isa(preop, LLVM.BitCastInst) + preop = operands(preop)[1] + end + v2, offset, skipload = + getparent(preop, offset, hasload) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + + if isa(v, LLVM.GetElementPtrInst) && all( + x -> (isa(x, LLVM.ConstantInt) && convert(Int, x) == 0), + operands(v)[2:end], + ) + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + + if isa(v, LLVM.GetElementPtrInst) + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + offset = nuwadd!( + b, + offset, + API.EnzymeComputeByteOffsetOfGEP(b, v, offty), + ) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + 
addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end + + undeforpoison = isa(v, LLVM.UndefValue) + @static if LLVM.version() >= v"12" + undeforpoison |= isa(v, LLVM.PoisonValue) + end + if undeforpoison + return LLVM.UndefValue( + LLVM.PointerType(eltype(value_type(v)), 10), + ), + offset, + addr == 13 + end + + if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v) + offset = nuwadd!(b, offset, goffsets[v]) + nv = nextvs[v] + return nv, offset, addr == 13 + end + + if isa(v, LLVM.SelectInst) + lhs_v, lhs_offset, lhs_skipload = + getparent(operands(v)[2], offset, hasload) + rhs_v, rhs_offset, rhs_skipload = + getparent(operands(v)[3], offset, hasload) + if value_type(lhs_v) != value_type(rhs_v) || + value_type(lhs_offset) != value_type(rhs_offset) || + lhs_skipload != rhs_skipload + msg = sprint() do io + println( + io, + "Could not analyze [select] garbage collection behavior of", + ) + println(io, " v0: ", string(v0)) + println(io, " v: ", string(v)) + println(io, " offset: ", string(offset)) + println(io, " hasload: ", string(hasload)) + println(io, " lhs_v", lhs_v) + println(io, " rhs_v", rhs_v) + println(io, " lhs_offset", lhs_offset) + println(io, " rhs_offset", rhs_offset) + println(io, " lhs_skipload", lhs_skipload) + println(io, " rhs_skipload", rhs_skipload) + end + bt = GPUCompiler.backtrace(inst) + throw(EnzymeInternalError(msg, string(f), bt)) + end + return select!(b, operands(v)[1], lhs_v, rhs_v), + select!(b, operands(v)[1], lhs_offset, rhs_offset), + lhs_skipload + end + + msg = sprint() do io + println(io, "Could not analyze garbage collection behavior of") + println(io, " inst: ", string(inst)) + println(io, " v0: ", string(v0)) + println(io, " v: ", string(v)) + println(io, " offset: ", string(offset)) + println(io, " hasload: ", string(hasload)) + end + bt = GPUCompiler.backtrace(inst) + throw(EnzymeInternalError(msg, string(f), bt)) + end + + v, offset, hadload = 
getparent(v, LLVM.ConstantInt(offty, 0), false) + + if addr == 13 + @assert hadload + end + + if eltype(value_type(v)) != el_ty + v = bitcast!( + b, + v, + LLVM.PointerType(el_ty, addrspace(value_type(v))), + ) + end + push!(nvs, (v, pb)) + push!(offsets, (offset, pb)) + end + + nb = IRBuilder() + position!(nb, nonphi) + + offset = goffsets[inst] + append!(LLVM.incoming(offset), offsets) + if all(x -> x[1] == offsets[1][1], offsets) + offset = offsets[1][1] + end + + nphi = nextvs[inst] + + function ogbc(@nospecialize(x::LLVM.Value)) + while isa(x, LLVM.BitCastInst) + x = operands(x)[1] + end + return x + end + + if all(x -> ogbc(x[1]) == ogbc(nvs[1][1]), nvs) + bc = ogbc(nvs[1][1]) + if value_type(bc) != value_type(nphi) + bc = bitcast!(nb, bc, value_type(nphi)) + end + replace_uses!(nphi, bc) + LLVM.API.LLVMInstructionEraseFromParent(nphi) + nphi = bc + else + append!(LLVM.incoming(nphi), nvs) + end + + if addr == 13 + @static if VERSION < v"1.11-" + nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) + nphi = load!(nb, ty, nphi) + else + base_obj = nphi + + jlt = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10) + pjlt = LLVM.PointerType(jlt) + + nphi = get_memory_data(nb, nphi) + nphi = bitcast!(nb, nphi, pjlt) + + GTy = LLVM.FunctionType(LLVM.PointerType(jlt, 13), LLVM.LLVMType[jlt, pjlt]) + gcloaded, _ = get_function!( + mod, + "julia.gc_loaded", + GTy + ) + nphi = call!(nb, GTy, gcloaded, LLVM.Value[base_obj, nphi]) + if value_type(nphi) != ty + nphi = bitcast!(nb, nphi, ty) + end + end + else + nphi = addrspacecast!(nb, nphi, ty) + end + if !isa(offset, LLVM.ConstantInt) || convert(Int64, offset) != 0 + nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, addrspace(ty))) + nphi = gep!(nb, i8, nphi, [offset]) + nphi = bitcast!(nb, nphi, ty) + end + replace_uses!(inst, nphi) + end + for inst in todo + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end + end + end + return nothing +end + 
+function fix_decayaddr!(mod::LLVM.Module) + for f in functions(mod) + invalid = LLVM.Instruction[] + for bb in blocks(f), inst in instructions(bb) + if !isa(inst, LLVM.AddrSpaceCastInst) + continue + end + prety = value_type(operands(inst)[1]) + postty = value_type(inst) + if addrspace(prety) != 10 + continue + end + if addrspace(postty) != 0 + continue + end + push!(invalid, inst) + end + + for inst in invalid + temp = nothing + for u in LLVM.uses(inst) + st = LLVM.user(u) + # Storing _into_ the decay addr is okay + # we just cannot store the decayed addr into + # somewhere + if isa(st, LLVM.StoreInst) + if operands(st)[2] == inst + LLVM.API.LLVMSetOperand(st, 2 - 1, operands(inst)[1]) + continue + end + end + if isa(st, LLVM.LoadInst) + LLVM.API.LLVMSetOperand(st, 1 - 1, operands(inst)[1]) + continue + end + # if isa(st, LLVM.InsertValueInst) + # if operands(st)[1] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 1-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # if operands(st)[2] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 2-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # end + if !isa(st, LLVM.CallInst) + bt = GPUCompiler.backtrace(st) + msg = sprint() do io::IO + println(io, string(f)) + println(io, inst) + println(io, st) + print(io, "Illegal decay of nonnull\n") + if bt !== nothing + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end + end + throw(AssertionError(msg)) + end + + fop = operands(st)[end] + + intr = LLVM.API.LLVMGetIntrinsicID(fop) + + if intr == LLVM.Intrinsic("llvm.memcpy").id || + intr == LLVM.Intrinsic("llvm.memmove").id || + intr == LLVM.Intrinsic("llvm.memset").id + newvs = LLVM.Value[] + for (i, v) in enumerate(operands(st)[1:end-1]) + if v == inst + LLVM.API.LLVMSetOperand(st, i - 1, operands(inst)[1]) + push!(newvs, operands(inst)[1]) + continue + end + push!(newvs, v) + end + + nb = IRBuilder() + position!(nb, st) + if intr == 
LLVM.Intrinsic("llvm.memcpy").id + newi = memcpy!(nb, newvs[1], 0, newvs[2], 0, newvs[3]) + elseif intr == LLVM.Intrinsic("llvm.memmove").id + newi = memmove!(nb, newvs[1], 0, newvs[2], 0, newvs[3]) + else + newi = memset!(nb, newvs[1], newvs[2], newvs[3], 0) + end + + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(operands(st))-1) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(st, idx) + + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(st, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newi, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + + API.EnzymeCopyMetadata(newi, st) + + LLVM.API.LLVMInstructionEraseFromParent(st) + continue + end + mayread = false + maywrite = false + sret = true + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) + for (i, v) in enumerate(operands(st)[1:end-1]) + if v == inst + readnone = false + readonly = false + writeonly = false + t_sret = false + for a in collect(parameter_attributes(fop, i)) + if kind(a) == sretkind + t_sret = true + end + if kind(a) == kind(StringAttribute("enzyme_sret")) + t_sret = true + end + # if kind(a) == kind(StringAttribute("enzyme_sret_v")) + # t_sret = true + # end + if kind(a) == kind(EnumAttribute("readonly")) + readonly = true + end + if kind(a) == kind(EnumAttribute("readnone")) + readnone = true + end + if kind(a) == kind(EnumAttribute("writeonly")) + writeonly = true + end + end + if !t_sret + sret = false + end + if readnone + continue + end + if !readonly + maywrite = true + end + if !writeonly + mayread = true + end + end + end + if !sret + msg = sprint() do io + println(io, "Enzyme Internal Error: did 
not have sret when expected") + println(io, "f=", string(f)) + println(io, "inst=", string(inst)) + println(io, "st=", string(st)) + println(io, "fop=", string(fop)) + end + throw(AssertionError(msg)) + end + + elt = eltype(value_type(inst)) + if temp === nothing + nb = IRBuilder() + position!(nb, first(instructions(first(blocks(f))))) + temp = alloca!(nb, elt) + end + if mayread + nb = IRBuilder() + position!(nb, st) + ld = load!(nb, elt, operands(inst)[1]) + store!(nb, ld, temp) + end + if maywrite + nb = IRBuilder() + position!(nb, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(st))) + ld = load!(nb, elt, temp) + si = store!(nb, ld, operands(inst)[1]) + julia_post_cache_store(si.ref, nb.ref, reinterpret(Ptr{UInt64}, C_NULL)) + end + end + + if temp !== nothing + replace_uses!(inst, temp) + end + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end + return nothing +end + +function pre_attr!(mod::LLVM.Module) + return nothing + tofinalize = Tuple{LLVM.Function,Bool,Vector{Int64}}[] + for fn in collect(functions(mod)) + if isempty(blocks(fn)) + continue + end + if linkage(fn) != LLVM.API.LLVMInternalLinkage && + linkage(fn) != LLVM.API.LLVMPrivateLinkage + continue + end + + fty = LLVM.FunctionType(fn) + nfn = LLVM.Function(mod, "enzyme_attr_prev_" * LLVM.name(enzymefn), fty) + LLVM.IRBuilder() do builder + entry = BasicBlock(nfn, "entry") + position!(builder, entry) + cv = call!(fn, [LLVM.UndefValue(ty) for ty in parameters(fty)]) + LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1), attr) + if LLVM.return_type(fty) == LLVM.VoidType() + ret!(builder) + else + ret!(builder, cv) + end + end + end + return nothing +end + +function prop_global!(g::LLVM.GlobalVariable) + newfns = String[] + changed = false + todo = Tuple{Vector{Cuint},LLVM.Value}[] + for u in LLVM.uses(g) + u = LLVM.user(u) + push!(todo, (Cuint[], u)) + end + while length(todo) > 0 + path, var = pop!(todo) + if isa(var, LLVM.LoadInst) + B = IRBuilder() + position!(B, var) + 
res = LLVM.initializer(g) + for p in path + res = extract_value!(B, res, p) + end + changed = true + for u in LLVM.uses(var) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + f2 = LLVM.called_operand(u) + if isa(f2, LLVM.Function) + push!(newfns, LLVM.name(f2)) + end + end + end + replace_uses!(var, res) + eraseInst(LLVM.parent(var), var) + continue + end + if isa(var, LLVM.AddrSpaceCastInst) + for u in LLVM.uses(var) + u = LLVM.user(u) + push!(todo, (path, u)) + end + continue + end + if isa(var, LLVM.ConstantExpr) && opcode(var) == LLVM.API.LLVMAddrSpaceCast + for u in LLVM.uses(var) + u = LLVM.user(u) + push!(todo, (path, u)) + end + continue + end + if isa(var, LLVM.GetElementPtrInst) + if all(isa(v, LLVM.ConstantInt) for v in operands(var)[2:end]) + if convert(Cuint, operands(var)[2]) == 0 + for u in LLVM.uses(var) + u = LLVM.user(u) + push!( + todo, + ( + vcat( + path, + collect(( + convert(Cuint, v) for v in operands(var)[3:end] + )), + ), + u, + ), + ) + end + end + continue + end + end + end + return changed, newfns +end + +# From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 +function mayWriteToMemory(@nospecialize(inst::LLVM.Instruction); err_is_readonly::Bool = false)::Bool + # we will ignore fense here + if isa(inst, LLVM.StoreInst) + return true + end + if isa(inst, LLVM.VAArgInst) + return true + end + if isa(inst, LLVM.AtomicCmpXchgInst) + return true + end + if isa(inst, LLVM.AtomicRMWInst) + return true + end + if isa(inst, LLVM.CatchPadInst) + return true + end + if isa(inst, LLVM.CatchRetInst) + return true + end + if isa(inst, LLVM.CallInst) || isa(inst, LLVM.InvokeInst) || isa(inst, LLVM.CallBrInst) + idx = reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + 
for j = 1:count + attr = LLVM.Attribute(unsafe_load(Attrs, j)) + if kind(attr) == kind(EnumAttribute("readnone")) + return false + end + if kind(attr) == kind(EnumAttribute("readonly")) + return false + end + # Note out of spec, and only legal in context of removing unused calls + if kind(attr) == kind(StringAttribute("enzyme_error")) && err_is_readonly + return false + end + if kind(attr) == kind(StringAttribute("memory")) + if is_readonly(MemoryEffect(value(attr))) + return false + end + end + end + Libc.free(Attrs) + return true + end + # Ignoring load unordered case + return false +end + +function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) + calls = LLVM.CallInst[] + + hasUser = false + for u in LLVM.uses(fn) + un = LLVM.user(u) + + # Only permit call users + if !isa(un, LLVM.CallInst) + return false + end + un = un::LLVM.CallInst + + # Passing the fn as an argument is not permitted + for op in collect(operands(un))[1:end-1] + if op == fn + return false + end + end + + # Something with a user is not permitted + for u2 in LLVM.uses(un) + hasUser = true + break + end + push!(calls, un) + end + + done = Set{LLVM.Function}() + todo = LLVM.Function[fn] + + while length(todo) != 0 + cur = pop!(todo) + if cur in done + continue + end + push!(done, cur) + + if is_readonly(cur) + continue + end + + if LLVM.name(cur) == "julia.safepoint" + continue + end + + if isempty(blocks(cur)) + return false + end + + err_is_readonly = !is_noreturn(cur) + + for bb in blocks(cur) + for inst in instructions(bb) + if !mayWriteToMemory(inst; err_is_readonly) + continue + end + if isa(inst, LLVM.CallInst) + + fn2 = LLVM.called_operand(inst) + if isa(fn2, LLVM.Function) + push!(todo, fn2) + continue + end + end + return false + end + end + end + + changed = set_readonly!(fn) + + if length(calls) == 0 || hasUser + return changed + end + + for c in calls + parentf = LLVM.parent(LLVM.parent(c)) + push!(next, LLVM.name(parentf)) + 
LLVM.API.LLVMInstructionEraseFromParent(c) + end + push!(next, LLVM.name(fn)) + return true +end + +function propagate_returned!(mod::LLVM.Module) + globs = LLVM.GlobalVariable[] + for g in globals(mod) + if linkage(g) == LLVM.API.LLVMInternalLinkage || + linkage(g) == LLVM.API.LLVMPrivateLinkage + if !isconstant(g) + continue + end + push!(globs, g) + end + end + todo = collect(functions(mod)) + while true + next = Set{String}() + changed = false + for g in globs + tc, tn = prop_global!(g) + changed |= tc + for f in tn + push!(next, f) + end + end + tofinalize = Tuple{LLVM.Function,Bool,Vector{Int64}}[] + for fn in functions(mod) + if isempty(blocks(fn)) + continue + end + if remove_readonly_unused_calls!(fn, next) + changed = true + end + attrs = collect(function_attributes(fn)) + prevent = any( + kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for + attr in attrs + ) + # if any(kind(attr) == kind(EnumAttribute("noinline")) for attr in attrs) + # continue + # end + argn = nothing + toremove = Int64[] + for (i, arg) in enumerate(parameters(fn)) + if any( + kind(attr) == kind(EnumAttribute("returned")) for + attr in collect(parameter_attributes(fn, i)) + ) + argn = i + end + + # remove unused sret-like + if !prevent && + ( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) && + any( + kind(attr) == kind(EnumAttribute("nocapture")) for + attr in collect(parameter_attributes(fn, i)) + ) + val = nothing + illegalUse = false + torem = LLVM.Instruction[] + argeltype = if LLVM.version().major >= 12 + # TODO try to get sret element type if possible + # note currently opaque pointers has this break [and we need to doa check if opaque + # and if so get inner piece] + eltype(value_type(arg)) + else + eltype(value_type(arg)) + end + for u in LLVM.uses(fn) + un = LLVM.user(u) + if !isa(un, LLVM.CallInst) + illegalUse = true + break + end + ops = collect(operands(un))[1:end-1] + bad = false + for op in ops + if op == 
fn + bad = true + break + end + end + if bad + illegalUse = true + break + end + if !isa(ops[i], LLVM.AllocaInst) && !isa(ops[i], LLVM.UndefValue) && !isa(ops[i], LLVM.PoisonValue) + illegalUse = true + break + end + eltype = if isa(ops[i], LLVM.AllocaInst) + LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i])) + else + LLVM.eltype(value_type(ops[i])) + end + seenfn = false + todo = LLVM.Instruction[] + if isa(ops[i], LLVM.AllocaInst) + for u2 in LLVM.uses(ops[i]) + un2 = LLVM.user(u2) + push!(todo, un2) + end + end + while length(todo) > 0 + un2 = pop!(todo) + if isa(un2, LLVM.BitCastInst) + push!(torem, un2) + for u3 in LLVM.uses(un2) + un3 = LLVM.user(u3) + push!(todo, un3) + end + continue + end + if isa(un2, LLVM.GetElementPtrInst) + push!(torem, un2) + for u3 in LLVM.uses(un2) + un3 = LLVM.user(u3) + push!(todo, un3) + end + continue + end + if !isa(un2, LLVM.CallInst) + illegalUse = true + break + end + ff = LLVM.called_operand(un2) + if !isa(ff, LLVM.Function) + illegalUse = true + break + end + if un2 == un && !seenfn + seenfn = true + continue + end + intr = LLVM.API.LLVMGetIntrinsicID(ff) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + push!(torem, un2) + continue + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + push!(torem, un2) + continue + end + if LLVM.name(ff) != "llvm.enzyme.sret_use" + illegalUse = true + break + end + push!(torem, un2) + end + if illegalUse + break + end + end + if !illegalUse + for c in reverse(torem) + eraseInst(LLVM.parent(c), c) + end + B = IRBuilder() + position!(B, first(instructions(first(blocks(fn))))) + al = alloca!(B, argeltype) + if value_type(al) != value_type(arg) + al = addrspacecast!(B, al, value_type(arg)) + end + LLVM.replace_uses!(arg, al) + end + end + + # interprocedural const prop from callers of arg + if !prevent && ( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) + val = nothing + illegalUse = false + for u in LLVM.uses(fn) + un = 
LLVM.user(u) + if !isa(un, LLVM.CallInst) + illegalUse = true + break + end + ops = collect(operands(un))[1:end-1] + bad = false + for op in ops + if op == fn + bad = true + break + end + end + if bad + illegalUse = true + break + end + if isa(ops[i], LLVM.UndefValue) || isa(ops[i], LLVM.PoisonValue) + continue + end + if ops[i] == arg + continue + end + if isa(ops[i], LLVM.Constant) + if val === nothing + val = ops[i] + else + if val != ops[i] + illegalUse = true + break + end + end + continue + end + illegalUse = true + break + end + if !illegalUse + if val === nothing + val = LLVM.UndefValue(value_type(arg)) + end + for u in LLVM.uses(arg) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + f2 = LLVM.called_operand(u) + if isa(f2, LLVM.Function) + push!(next, LLVM.name(f2)) + end + end + changed = true + end + LLVM.replace_uses!(arg, val) + end + end + # see if there are no users of the value (excluding recursive/return) + baduse = false + for u in LLVM.uses(arg) + u = LLVM.user(u) + if argn == i && LLVM.API.LLVMIsAReturnInst(u) != C_NULL + continue + end + if !isa(u, LLVM.CallInst) + baduse = true + break + end + if LLVM.called_operand(u) != fn + baduse = true + break + end + for (si, op) in enumerate(operands(u)) + if si == i + continue + end + if op == arg + baduse = true + break + end + end + if baduse + break + end + end + if !baduse + push!(toremove, i - 1) + end + end + illegalUse = !( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) + hasAnyUse = false + for u in LLVM.uses(fn) + un = LLVM.user(u) + if !isa(un, LLVM.CallInst) + illegalUse = true + continue + end + ops = collect(operands(un))[1:end-1] + bad = false + for op in ops + if op == fn + bad = true + break + end + end + if bad + illegalUse = true + continue + end + if argn !== nothing + hasUse = false + for u in LLVM.uses(un) + hasUse = true + break + end + if hasUse + changed = true + push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) + 
LLVM.replace_uses!(un, ops[argn]) + end + else + for u in LLVM.uses(un) + hasAnyUse = true + break + end + end + end + #if the function return has no users whatsoever, remove it + if argn === nothing && + !hasAnyUse && + LLVM.return_type(LLVM.function_type(fn)) != LLVM.VoidType() + argn = -1 + end + if argn === nothing && length(toremove) == 0 + continue + end + if !illegalUse + push!(tofinalize, (fn, argn === nothing, toremove)) + end + end + for (fn, keepret, toremove) in tofinalize + todo = LLVM.CallInst[] + for u in LLVM.uses(fn) + un = LLVM.user(u) + push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) + end + delete_writes_into_removed_args(fn, toremove, keepret) + nm = LLVM.name(fn) + #try + nfn = LLVM.Function( + API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove), + ) + for u in LLVM.uses(fn) + un = LLVM.user(u) + push!(todo, un) + end + for un in todo + md = metadata(un) + if !keepret && haskey(md, LLVM.MD_range) + delete!(md, LLVM.MD_range) + end + API.EnzymeSetCalledFunction(un, nfn, toremove) + end + eraseInst(mod, fn) + changed = true + # catch e + # break + #end + end + if !changed + break + else + todo = LLVM.Function[] + for name in next + fn = functions(mod)[name] + if linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + has_user = false + for u in LLVM.uses(fn) + has_user = true + break + end + if !has_user + LLVM.API.LLVMDeleteFunction(fn) + end + end + push!(todo, fn) + end + end + end +end + +function delete_writes_into_removed_args(fn::LLVM.Function, toremove::Vector{Int64}, keepret::Bool) + args = collect(parameters(fn)) + for tr in toremove + tr = tr + 1 + todorep = Tuple{LLVM.Instruction, LLVM.Value}[] + for opv in LLVM.uses(args[tr]) + u = LLVM.user(opv) + push!(todorep, (u, args[tr])) + end + toerase = LLVM.Instruction[] + while length(todorep) != 0 + cur, cval = pop!(todorep) + if isa(cur, LLVM.StoreInst) + if operands(cur)[2] == cval + LLVM.API.LLVMInstructionEraseFromParent(nphi) 
+ continue + end + end + if isa(cur, LLVM.GetElementPtrInst) || + isa(cur, LLVM.BitCastInst) || + isa(cur, LLVM.AddrSpaceCastInst) + for opv in LLVM.uses(cur) + u = LLVM.user(opv) + push!(todorep, (u, cur)) + end + continue + end + if isa(cur, LLVM.CallInst) + cf = LLVM.called_operand(cur) + if cf == fn + baduse = false + for (i, v) in enumerate(operands(cur)) + if i-1 in toremove + continue + end + if v == cval + baduse = true + end + end + if !baduse + continue + end + end + end + if !keepret && LLVM.API.LLVMIsAReturnInst(cur) != C_NULL + LLVM.API.LLVMSetOperand(cur, 0, LLVM.UndefValue(value_type(cval))) + continue + end + throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))")) + end + end +end + +function detect_writeonly!(mod::LLVM.Module) + for f in functions(mod) + if isempty(LLVM.blocks(f)) + continue + end + for (i, a) in enumerate(parameters(f)) + if isa(value_type(a), LLVM.PointerType) + todo = Tuple{LLVM.Value,LLVM.Instruction}[] + for u in LLVM.uses(a) + push!(todo, (a, LLVM.user(u))) + end + seen = Set{Tuple{LLVM.Value,LLVM.Instruction}}() + mayread = false + maywrite = false + while length(todo) > 0 + cur = pop!(todo) + if in(cur, seen) + continue + end + push!(seen, cur) + curv, curi = cur + + if isa(curi, LLVM.StoreInst) + if operands(curi)[1] != curv + maywrite = true + continue + end + end + + if isa(curi, LLVM.LoadInst) + mayread = true + continue + end + + if isa(curi, LLVM.GetElementPtrInst) || + isa(curi, LLVM.BitCastInst) || + isa(curi, LLVM.AddrSpaceCastInst) + for u in LLVM.uses(curi) + push!(todo, (curi, LLVM.user(u))) + end + continue + end + mayread = true + maywrite = true + end + if any( + map( + k -> kind(k) == kind(EnumAttribute("readnone")), + collect(parameter_attributes(f, i)), + ), + ) + mayread = false + maywrite = false + end + if any( + map( + k -> kind(k) == kind(EnumAttribute("readonly")), + collect(parameter_attributes(f, i)), + ), + ) + maywrite = false + end + if any( + 
map( + k -> kind(k) == kind(EnumAttribute("writeonly")), + collect(parameter_attributes(f, i)), + ), + ) + mayread = false + end + + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("readnone")), + ) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("readonly")), + ) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("writeonly")), + ) + + if !mayread && !maywrite + push!(parameter_attributes(f, i), LLVM.EnumAttribute("readnone", 0)) + elseif !mayread + push!(parameter_attributes(f, i), LLVM.EnumAttribute("writeonly", 0)) + elseif !maywrite + push!(parameter_attributes(f, i), LLVM.EnumAttribute("readonly", 0)) + end + + end + end + end + return nothing +end + +function validate_return_roots!(mod::LLVM.Module) + for f in functions(mod) + srets = [] + enzyme_srets = Int[] + enzyme_srets_v = Int[] + rroots = Int[] + rroots_v = Int[] + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) + for (i, a) in enumerate(parameters(f)) + for attr in collect(parameter_attributes(f, i)) + if isa(attr, StringAttribute) + if kind(attr) == "enzymejl_returnRoots" + push!(rroots, i) + end + if kind(attr) == "enzymejl_returnRoots_v" + push!(rroots_v, i) + end + if kind(attr) == "enzyme_sret" + push!(enzyme_srets, i) + end + if kind(attr) == "enzyme_sret_v" + push!(enzyme_srets, i) + end + end + if kind(attr) == sretkind + push!(srets, (i, attr)) + end + end + end + if length(enzyme_srets) >= 1 && length(srets) == 0 + @assert enzyme_srets[1] == 1 + VT = LLVM.VoidType() + if length(enzyme_srets) == 1 && + LLVM.return_type(LLVM.function_type(f)) == VT && + length(enzyme_srets_v) == 0 + # Upgrading to sret requires writeonly + if !any( + kind(attr) == kind(EnumAttribute("writeonly")) for + attr in collect(parameter_attributes(f, 1)) + ) + msg = sprint() do io::IO 
+ println(io, "Enzyme internal error (not writeonly sret)") + println(io, string(f)) + println( + io, + "collect(parameter_attributes(f, 1))=", + collect(parameter_attributes(f, 1)), + ) + end + throw(AssertionError(msg)) + end + + alty = nothing + for u in LLVM.uses(f) + u = LLVM.user(u) + @assert isa(u, LLVM.CallInst) + @assert LLVM.called_operand(u) == f + alop = operands(u)[1] + if !isa(alop, LLVM.AllocaInst) + msg = sprint() do io::IO + println(io, "Enzyme internal error (!isa(alop, LLVM.AllocaInst))") + println(io, "alop=", alop) + println(io, "u=", u) + println(io, "f=", string(f)) + end + throw(AssertionError(msg)) + + end + @assert isa(alop, LLVM.AllocaInst) + nty = API.EnzymeAllocaType(alop) + if alty === nothing + alty = nty + else + @assert alty == nty + end + attr = if LLVM.version().major >= 12 + TypeAttribute("sret", alty) + else + EnumAttribute("sret") + end + LLVM.API.LLVMAddCallSiteAttribute( + u, + LLVM.API.LLVMAttributeIndex(1), + attr, + ) + LLVM.API.LLVMRemoveCallSiteStringAttribute( + u, + LLVM.API.LLVMAttributeIndex(1), + "enzyme_sret", + length("enzyme_sret"), + ) + end + @assert alty !== nothing + attr = if LLVM.version().major >= 12 + TypeAttribute("sret", alty) + else + EnumAttribute("sret") + end + + push!(parameter_attributes(f, 1), attr) + delete!(parameter_attributes(f, 1), StringAttribute("enzyme_sret")) + srets = [(1, attr)] + enzyme_srets = Int[] + else + + enzyme_srets2 = Int[] + for idx in enzyme_srets + alty = nothing + bad = false + for u in LLVM.uses(f) + u = LLVM.user(u) + @assert isa(u, LLVM.CallInst) + @assert LLVM.called_operand(u) == f + alop = operands(u)[1] + @assert isa(alop, LLVM.AllocaInst) + nty = API.EnzymeAllocaType(alop) + if any_jltypes(nty) + bad = true + end + LLVM.API.LLVMRemoveCallSiteStringAttribute( + u, + LLVM.API.LLVMAttributeIndex(idx), + "enzyme_sret", + length("enzyme_sret"), + ) + end + if !bad + delete!( + parameter_attributes(f, idx), + StringAttribute("enzyme_sret"), + ) + else + 
push!(enzyme_srets2, idx) + end + end + enzyme_srets = enzyme_srets2 + + if length(enzyme_srets) != 0 + msg = sprint() do io::IO + println(io, "Enzyme internal error (length(enzyme_srets) != 0)") + println(io, "f=", string(f)) + println(io, "enzyme_srets=", enzyme_srets) + println(io, "enzyme_srets_v=", enzyme_srets_v) + println(io, "srets=", srets) + println(io, "rroots=", rroots) + println(io, "rroots_v=", rroots_v) + end + throw(AssertionError(msg)) + end + end + end + @assert length(enzyme_srets_v) == 0 + for (i, attr) in srets + @assert i == 1 + end + for i in rroots + @assert length(srets) != 0 + @assert i == 2 + end + # illegal + for i in rroots_v + @assert false + end + end +end + +function checkNoAssumeFalse(mod::LLVM.Module, shouldshow::Bool = false) + for f in functions(mod) + for bb in blocks(f), inst in instructions(bb) + if !isa(inst, LLVM.CallInst) + continue + end + intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(inst)) + if intr != LLVM.Intrinsic("llvm.assume").id + continue + end + op = operands(inst)[1] + if isa(op, LLVM.ConstantInt) + op2 = convert(Bool, op) + if !op2 + msg = sprint() do io + println(io, "Enzyme Internal Error: non-constant assume condition") + println(io, "mod=", string(mod)) + println(io, "f=", string(f)) + println(io, "bb=", string(bb)) + println(io, "op2=", string(op2)) + end + throw(AssertionError(msg)) + end + end + if isa(op, LLVM.ICmpInst) + if predicate_int(op) == LLVM.API.LLVMIntNE && + operands(op)[1] == operands(op)[2] + msg = sprint() do io + println(io, "Enzyme Internal Error: non-icmp assume condition") + println(io, "mod=", string(mod)) + println(io, "f=", string(f)) + println(io, "bb=", string(bb)) + println(io, "op=", string(op)) + end + throw(AssertionError(msg)) + end + end + end + end +end + +function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine) + # We need to run globalopt first. 
This is because remove dead args will otherwise + # take internal functions and replace their args with undef. Then on LLVM up to + # and including 12 (but fixed 13+), Attributor will incorrectly change functions that + # call code with undef to become unreachable, even when there exist other valid + # callsites. See: https://godbolt.org/z/9Y3Gv6q5M + ModulePassManager() do pm + global_dce!(pm) + LLVM.run!(pm, mod) + end + # Prevent dead-arg-elimination of functions which we may require args for in the derivative + funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg = true) + if LLVM.version().major <= 15 + func, _ = get_function!( + mod, + "llvm.enzymefakeuse", + funcT, + LLVM.Attribute[EnumAttribute("readnone"), EnumAttribute("nofree")], + ) + rfunc, _ = get_function!( + mod, + "llvm.enzymefakeread", + funcT, + LLVM.Attribute[ + EnumAttribute("readonly"), + EnumAttribute("nofree"), + EnumAttribute("argmemonly"), + ], + ) + sfunc, _ = get_function!( + mod, + "llvm.enzyme.sret_use", + funcT, + LLVM.Attribute[ + EnumAttribute("readonly"), + EnumAttribute("nofree"), + EnumAttribute("argmemonly"), + ], + ) + else + func, _ = get_function!( + mod, + "llvm.enzymefakeuse", + funcT, + LLVM.Attribute[EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")], + ) + rfunc, _ = get_function!( + mod, + "llvm.enzymefakeread", + funcT, + LLVM.Attribute[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + ) + sfunc, _ = get_function!( + mod, + "llvm.enzyme.sret_use", + funcT, + LLVM.Attribute[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + ) + end + + for fn in functions(mod) + if isempty(blocks(fn)) + continue + end + # Ensure that interprocedural optimizations do not delete the use of returnRoots (or shadows) + # if inactive sret, this will only occur on 2. If active sret, inactive retRoot, can on 3, and + # active both can occur on 4. 
If the original sret is removed (at index 1) we no longer need + # to preserve this. + for idx in (2, 3, 4) + if length(collect(parameters(fn))) >= idx && any( + ( + kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || + kind(attr) == kind(StringAttribute("enzymejl_returnRoots_v")) + ) for attr in collect(parameter_attributes(fn, idx)) + ) + for u in LLVM.uses(fn) + u = LLVM.user(u) + @assert isa(u, LLVM.CallInst) + B = IRBuilder() + nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) + position!(B, nextInst) + inp = operands(u)[idx] + cl = call!(B, funcT, rfunc, LLVM.Value[inp]) + if isa(value_type(inp), LLVM.PointerType) + LLVM.API.LLVMAddCallSiteAttribute( + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), + ) + end + end + end + end + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) + for idx in (1, 2) + if length(collect(parameters(fn))) < idx + continue + end + attrs = collect(parameter_attributes(fn, idx)) + if any( + ( + kind(attr) == sretkind || + kind(attr) == kind(StringAttribute("enzyme_sret")) || + kind(attr) == kind(StringAttribute("enzyme_sret_v")) + ) for attr in attrs + ) && any_jltypes(sret_ty(fn, idx)) + for u in LLVM.uses(fn) + u = LLVM.user(u) + if isa(u, LLVM.ConstantExpr) + u = LLVM.user(only(LLVM.uses(u))) + end + if !isa(u, LLVM.CallInst) + continue + end + @assert isa(u, LLVM.CallInst) + B = IRBuilder() + nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) + position!(B, nextInst) + inp = operands(u)[idx] + cl = call!(B, funcT, sfunc, LLVM.Value[inp]) + if isa(value_type(inp), LLVM.PointerType) + LLVM.API.LLVMAddCallSiteAttribute( + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), + ) + end + end + end + end + attrs = collect(function_attributes(fn)) + prevent = any( + kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs + ) + # && any(kind(attr) == 
kind(StringAttribute("enzyme_math")) for attr in attrs) + if prevent + B = IRBuilder() + position!(B, first(instructions(first(blocks(fn))))) + call!(B, funcT, func, LLVM.Value[p for p in parameters(fn)]) + end + end + propagate_returned!(mod) + ModulePassManager() do pm + instruction_combining!(pm) + jl_inst_simplify!(pm) + alloc_opt_tm!(pm, tm) + scalar_repl_aggregates_ssa!(pm) # SSA variant? + cse!(pm) + LLVM.run!(pm, mod) + end + propagate_returned!(mod) + pre_attr!(mod) + if RunAttributor[] + if LLVM.version().major >= 13 + ModulePassManager() do pm + API.EnzymeAddAttributorLegacyPass(pm) + LLVM.run!(pm, mod) + end + end + end + propagate_returned!(mod) + ModulePassManager() do pm + instruction_combining!(pm) + jl_inst_simplify!(pm) + alloc_opt_tm!(pm, tm) + scalar_repl_aggregates_ssa!(pm) # SSA variant? + if RunAttributor[] + if LLVM.version().major >= 13 + API.EnzymeAddAttributorLegacyPass(pm) + end + end + cse!(pm) + LLVM.run!(pm, mod) + end + post_attr!(mod) + propagate_returned!(mod) + + for u in LLVM.uses(rfunc) + u = LLVM.user(u) + eraseInst(LLVM.parent(u), u) + end + eraseInst(mod, rfunc) + for u in LLVM.uses(sfunc) + u = LLVM.user(u) + eraseInst(LLVM.parent(u), u) + end + eraseInst(mod, sfunc) + for fn in functions(mod) + for b in blocks(fn) + inst = first(LLVM.instructions(b)) + if isa(inst, LLVM.CallInst) + fn = LLVM.called_operand(inst) + if fn == func + eraseInst(b, inst) + end + end + end + end + eraseInst(mod, func) +end + diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 13bacb06a5..56db23430e 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -19,14 +19,7 @@ function julia_activity_rule(f::LLVM.Function) end expectLen -= length(parmsRemoved) - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(f, i)), - ), - ) for i = 1:length(collect(parameters(f))) - ) + swiftself = has_swiftself(f) if swiftself expectLen += 1 diff --git 
a/src/rules/customrules.jl b/src/rules/customrules.jl index 55f3136286..6956842462 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -33,14 +33,7 @@ function enzyme_custom_setup_args( returnRoots = returnRoots !== nothing cv = LLVM.called_operand(orig) - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(cv, i)), - ), - ) for i = 1:length(collect(parameters(cv))) - ) + swiftself = has_swiftself(cv) jlargs = classify_arguments( mi.specTypes, called_type(orig), @@ -523,40 +516,31 @@ end world = enzyme_extract_world(fn) @safe_debug "Trying to apply custom forward rule" TT isKWCall llvmf = nothing - if isKWCall - if EnzymeRules.isapplicable(kwfunc, TT; world) - @safe_debug "Applying custom forward rule (kwcall)" TT - llvmf = nested_codegen!(mode, mod, kwfunc, TT, world) - fwd_RT = Compiler.primal_return_type_world(Forward, world, Core.Typeof(kwfunc), TT) - else - TT = Tuple{typeof(world),typeof(kwfunc),TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - fwd_RT = Union{} - end + + functy = if isKWCall + rkwfunc = typeof(Core.kwfunc(EnzymeRules.forward)) else - if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world) - @safe_debug "Applying custom forward rule" TT - llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world) - fwd_RT = Compiler.primal_return_type_world(Forward, world, typeof(EnzymeRules.forward), TT) - else - TT = Tuple{typeof(world),typeof(EnzymeRules.forward),TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - fwd_RT = Union{} - end + typeof(EnzymeRules.forward) end + @safe_debug "Applying custom forward rule" TT = TT, functy = functy + fmi, fwd_RT = try + fmi = my_methodinstance(functy, TT, world) + fwd_RT = primal_return_type_world(Forward, world, fmi) + fmi, fwd_RT + catch e + TT = 
Tuple{typeof(world),functy,TT.parameters...} + fmi = my_methodinstance(typeof(custom_rule_method_error), TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + fwd_RT = Union{} + fmi, fwd_RT + end + fmi = fmi::Core.MethodInstance + fwd_RT = fwd_RT::Type + llvmf = nested_codegen!(mode, mod, fmi, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(llvmf, i)), - ), - ) for i = 1:length(collect(parameters(llvmf))) - ) + swiftself = has_swiftself(llvmf) if swiftself pushfirst!(reinsert_gcmarker!(fn, B)) end @@ -603,12 +587,7 @@ end debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any( - map( - k -> kind(k) == kind(EnumAttribute("noreturn")), - collect(function_attributes(llvmf)), - ), - ) + hasNoRet = has_fn_attr(llvmf, EnumAttribute("noreturn")) if hasNoRet return false @@ -802,10 +781,9 @@ end mode = get_mode(gutils) - ami = nothing augprimal_tt = copy(activity) - if isKWCall + functy = if isKWCall popfirst!(augprimal_tt) @assert kwtup !== nothing insert!(augprimal_tt, 1, kwtup) @@ -814,50 +792,32 @@ end insert!(augprimal_tt, 5, Type{RT}) augprimal_TT = Tuple{augprimal_tt...} - kwfunc = Core.kwfunc(EnzymeRules.augmented_primal) - try - ami = my_methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) - @safe_debug "Applying custom augmented_primal rule (kwcall)" TT = augprimal_TT - catch e - augprimal_TT = Tuple{typeof(world),typeof(kwfunc),augprimal_TT.parameters...} - ami = my_methodinstance( - typeof(custom_rule_method_error), - augprimal_TT, - world, - ) - if forward - pushfirst!(args, LLVM.ConstantInt(world)) - end - end + Core.Typeof(Core.kwfunc(EnzymeRules.augmented_primal)) else @assert kwtup === nothing insert!(augprimal_tt, 1, C) insert!(augprimal_tt, 3, Type{RT}) augprimal_TT = Tuple{augprimal_tt...} - try - ami = my_methodinstance( - Core.Typeof(EnzymeRules.augmented_primal), - 
augprimal_TT, - world, - ) - @safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT - catch e - augprimal_TT = Tuple{ - typeof(world), - typeof(EnzymeRules.augmented_primal), - augprimal_TT.parameters..., - } - ami = my_methodinstance( - typeof(custom_rule_method_error), - augprimal_TT, - world, - ) - if forward - pushfirst!(args, LLVM.ConstantInt(world)) - end + typeof(EnzymeRules.augmented_primal) + end + + ami = try + my_methodinstance(functy, augprimal_TT, world) + catch e + augprimal_TT = Tuple{typeof(world),functy,augprimal_TT.parameters...} + ami = my_methodinstance( + typeof(custom_rule_method_error), + augprimal_TT, + world, + ) + if forward + pushfirst!(args, LLVM.ConstantInt(world)) end + ami end + ami = ami::Core.MethodInstance + @safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT, functy=functy return ami, augprimal_TT, ( @@ -942,17 +902,10 @@ function enzyme_custom_common_rev( @assert ami !== nothing target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) - aug_RT = something( - Core.Compiler.typeinf_type( - GPUCompiler.get_interpreter( - CompilerJob(ami, CompilerConfig(target, params; kernel = false), world), - ), - ami.def, - ami.specTypes, - ami.sparam_vals, - ), - Any, + interp = GPUCompiler.get_interpreter( + CompilerJob(ami, CompilerConfig(target, params; kernel = false), world), ) + aug_RT = return_type(interp, ami) if kwtup !== nothing && kwtup <: Duplicated @safe_debug "Non-constant keyword argument found for " augprimal_TT emit_error( @@ -964,7 +917,6 @@ function enzyme_custom_common_rev( end rev_TT = nothing - rev_RT = nothing TapeT = Nothing @@ -1006,6 +958,7 @@ function enzyme_custom_common_rev( if forward llvmf = nested_codegen!(mode, mod, ami, world) @assert llvmf !== nothing + rev_RT = nothing else tt = copy(activity) if isKWCall @@ -1024,50 +977,28 @@ function enzyme_custom_common_rev( end rev_TT = Tuple{tt...} - if isKWCall - rkwfunc = Core.kwfunc(EnzymeRules.reverse) - if 
EnzymeRules.isapplicable(rkwfunc, rev_TT; world) - @safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT - try - llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) - rev_RT = Compiler.primal_return_type_world(Reverse, world, Core.Typeof(rkwfunc), rev_TT) - catch e - rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end - else - rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end + functy = if isKWCall + rkwfunc = typeof(Core.kwfunc(EnzymeRules.reverse)) else - if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) - @safe_debug "Applying custom reverse rule" TT = rev_TT - try - llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) - rev_RT = Compiler.primal_return_type_world(Reverse, world, typeof(EnzymeRules.reverse), rev_TT) - catch e - rev_TT = - Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end - else - rev_TT = - Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end + typeof(EnzymeRules.reverse) end + + @safe_debug "Applying custom reverse rule" TT = rev_TT, functy=functy + rmi, rev_RT = try + rmi = my_methodinstance(functy, rev_TT, world) + rev_RT = return_type(interp, rmi) + rmi, rev_RT + catch e + rev_TT = Tuple{typeof(world),functy,rev_TT.parameters...} + rmi = 
my_methodinstance(typeof(custom_rule_method_error), rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} + applicablefn = false + rmi, rev_RT + end + rmi = rmi::Core.MethodInstance + rev_RT = rev_RT::Type + llvmf = nested_codegen!(mode, mod, rmi, world) end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) @@ -1092,14 +1023,7 @@ function enzyme_custom_common_rev( # llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}, world) # end - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(llvmf, i)), - ), - ) for i = 1:length(collect(parameters(llvmf))) - ) + swiftself = has_swiftself(llvmf) miRT = enzyme_custom_extract_mi(llvmf)[2] _, sret, returnRoots = get_return_info(miRT) @@ -1311,12 +1235,7 @@ function enzyme_custom_common_rev( debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any( - map( - k -> kind(k) == kind(EnumAttribute("noreturn")), - collect(function_attributes(llvmf)), - ), - ) + hasNoRet = has_fn_attr(llvmf, EnumAttribute("noreturn")) if hasNoRet return tapeV @@ -1608,10 +1527,7 @@ end fop = called_operand(orig)::LLVM.Function for (i, v) in enumerate(operands(orig)[1:end-1]) if v == val - if !any( - a -> kind(a) == kind(StringAttribute("enzymejl_returnRoots")), - collect(parameter_attributes(fop, i)), - ) + if !has_fn_attr(fop, StringAttribute("enzymejl_returnRoots")) non_rooting_use = true break end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 35399d0b24..a7e3fbd712 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -160,12 +160,7 @@ include("parallelrules.jl") if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_fwd(2, B, orig, gutils, normalR, shadowR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attr(F, StringAttribute("enzyme_inactive")) return true end end 
@@ -234,12 +229,7 @@ end if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attr(F, StringAttribute("enzyme_inactive")) return true end end @@ -317,12 +307,7 @@ end common_finalizer_rev(2, B, orig, gutils, tape) return nothing end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attr(F, StringAttribute("enzyme_inactive")) return nothing end end @@ -343,12 +328,7 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_fwd(2, B, orig, gutils, normalR, shadowR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attr(F, StringAttribute("enzyme_inactive")) return true end end @@ -365,12 +345,7 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attr(F, StringAttribute("enzyme_inactive")) return true end end @@ -388,12 +363,7 @@ end common_invoke_rev(2, B, orig, gutils, tape) return nothing end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attr(F, StringAttribute("enzyme_inactive")) return nothing end end diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 08f3f04afd..78c9cd9ce8 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -32,6 +32,11 @@ function runtime_newtask_fwd( return ccall(:jl_new_task, Ref{Task}, (Any, Any, Int), fclosure, post, ssize) end +struct Return2 + ret1::Any + ret2::Any +end + function runtime_newtask_augfwd( fn::FT1, 
dfn::FT2, diff --git a/src/sugar.jl b/src/sugar.jl index ba53f46a00..3e68830100 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -15,13 +15,10 @@ end params = Compiler.PrimalCompilerParams(API.DEM_ForwardMode) mi = my_methodinstance(fn, Tuple{T, Int}) job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params; kernel = false)) - mod, meta = GPUCompiler.codegen( - :llvm, - job; - optimize = false, - cleanup = false, - validate = false, - ) + + GPUCompiler.prepare_job!(job) + mod, meta = GPUCompiler.emit_llvm(job; libraries=true, toplevel=true, optimize=false, cleanup=false, only_entry=false, validate=false) + copysetfn = meta.entry blk = first(LLVM.blocks(copysetfn)) iter = LLVM.API.LLVMGetFirstInstruction(blk) @@ -40,12 +37,7 @@ end end end end - hasNoRet = any( - map( - k -> kind(k) == kind(LLVM.EnumAttribute("noreturn")), - collect(LLVM.function_attributes(copysetfn)), - ), - ) + hasNoRet = Compiler.has_fn_attr(copysetfn, LLVM.EnumAttribute("noreturn")) @assert !hasNoRet if !hasNoRet push!(LLVM.function_attributes(copysetfn), LLVM.EnumAttribute("alwaysinline", 0)) diff --git a/src/typeutils/conversion.jl b/src/typeutils/conversion.jl new file mode 100644 index 0000000000..27644aefcb --- /dev/null +++ b/src/typeutils/conversion.jl @@ -0,0 +1,128 @@ +# return result and if contains any +function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} + tkind = LLVM.API.LLVMGetTypeKind(Type) + if tkind == LLVM.API.LLVMStructTypeKind + tys = DataType[] + nelems = LLVM.API.LLVMCountStructElementTypes(Type) + containsAny = false + syms = Symbol[] + for i = 1:nelems + e = LLVM.API.LLVMStructGetTypeAtIndex(Type, i - 1) + T, sub = to_tape_type(e) + containsAny |= sub + push!(tys, T) + push!(syms, Symbol(i)) + end + Tup = Tuple{tys...} + if containsAny + res = (syms...,) + return NamedTuple{res,Tup}, false + else + return Tup, false + end + end + if tkind == LLVM.API.LLVMPointerTypeKind + addrspace = LLVM.API.LLVMGetPointerAddressSpace(Type) + if 10 
<= addrspace <= 12 + return Any, true + else + e = LLVM.API.LLVMGetElementType(Type) + tkind2 = LLVM.API.LLVMGetTypeKind(e) + if tkind2 == LLVM.API.LLVMFunctionTypeKind + return Core.LLVMPtr{Cvoid,Int(addrspace)}, false + else + return Core.LLVMPtr{to_tape_type(e)[1],Int(addrspace)}, false + end + end + end + if tkind == LLVM.API.LLVMArrayTypeKind + e = LLVM.API.LLVMGetElementType(Type) + T, sub = to_tape_type(e) + len = Int(LLVM.API.LLVMGetArrayLength(Type)) + Tup = NTuple{len,T} + if sub + return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false + else + return Tup, false + end + end + if tkind == LLVM.API.LLVMVectorTypeKind + e = LLVM.API.LLVMGetElementType(Type) + T, sub = to_tape_type(e) + len = Int(LLVM.API.LLVMGetVectorSize(Type)) + Tup = NTuple{len,T} + if sub + return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false + else + return Tup, false + end + end + if tkind == LLVM.API.LLVMIntegerTypeKind + N = LLVM.API.LLVMGetIntTypeWidth(Type) + if N == 1 + return Bool, false + elseif N == 8 + return UInt8, false + elseif N == 16 + return UInt16, false + elseif N == 32 + return UInt32, false + elseif N == 64 + return UInt64, false + elseif N == 128 + return UInt128, false + elseif N == 256 + return UInt256, false + elseif N == 512 + return UInt512, false + elseif N == 1024 + return UInt1024, false + elseif N == 2048 + return UInt2048, false + else + error("Can't construct tape type for integer of width $N") + end + end + if tkind == LLVM.API.LLVMHalfTypeKind + return Float16, false + end + if tkind == LLVM.API.LLVMFloatTypeKind + return Float32, false + end + if tkind == LLVM.API.LLVMDoubleTypeKind + return Float64, false + end + if tkind == LLVM.API.LLVMFP128TypeKind + return Float128, false + end + error("Can't construct tape type for $Type $(string(Type)) $tkind") +end + +function tape_type(@nospecialize(LLVMType::LLVM.LLVMType)) + TT, isAny = to_tape_type(LLVMType.ref) + if isAny + return AnonymousStruct(Tuple{Any}) + end + return TT +end + 
+from_tape_type(::Type{T}) where {T<:AbstractFloat} = convert(LLVMType, T) +from_tape_type(::Type{T}) where {T<:Integer} = convert(LLVMType, T) +from_tape_type(::Type{NTuple{Size,T}}) where {Size,T} = + LLVM.ArrayType(from_tape_type(T), Size) +from_tape_type(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = + LLVM.PointerType(from_tape_type(UInt8), Addr) +# from_tape_type(::Type{Core.LLVMPtr{T, Addr}}, ctx) where {T, Addr} = LLVM.PointerType(from_tape_type(T, ctx), Addr) +from_tape_type(::Type{Any}) = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), Tracked) +function from_tape_type(::Type{NamedTuple{A,B}}) where {A,B} + from_tape_type(B) +end +function from_tape_type(::Type{B}) where {B<:Tuple} + ar = LLVM.LLVMType[from_tape_type(b) for b in B.parameters] + if length(B.parameters) >= 1 && all(ar[1] == b for b in ar) + return LLVM.ArrayType(ar[1], length(B.parameters)) + else + return LLVM.StructType(LLVM.LLVMType[from_tape_type(b) for b in B.parameters]) + end +end + diff --git a/src/typeutils/inference.jl b/src/typeutils/inference.jl new file mode 100644 index 0000000000..cd000b70c9 --- /dev/null +++ b/src/typeutils/inference.jl @@ -0,0 +1,182 @@ +function return_type(interp::Core.Compiler.AbstractInterpreter, mi::Core.MethodInstance)::Type + @static if VERSION < v"1.11.0" + code = Core.Compiler.get(Core.Compiler.code_cache(interp), mi, nothing) + if code isa Core.Compiler.CodeInstance + return code.rettype + end + result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) + Core.Compiler.typeinf(interp, result, :global) + Core.Compiler.is_inferred(result) || return Any + Core.Compiler.widenconst(Core.Compiler.ignorelimited(result.result)) + else + something(Core.Compiler.typeinf_type(interp, mi), Any) + end +end + +function primal_interp_world( + @nospecialize(::ReverseMode), + world::UInt +) + mode = Enzyme.API.DEM_ReverseModeCombined + + CT = @static if VERSION >= v"1.11.0-DEV.1552" + EnzymeCacheToken( + typeof(DefaultCompilerTarget()), + 
false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# + EnzymeCompilerParams, + false, + ) + else + Enzyme.Compiler.GLOBAL_REV_CACHE + end + + Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) +end + +function primal_interp_world( + @nospecialize(::ForwardMode), + world::UInt +) + mode = Enzyme.API.DEM_ForwardMode + + CT = @static if VERSION >= v"1.11.0-DEV.1552" + EnzymeCacheToken( + typeof(DefaultCompilerTarget()), + false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# + EnzymeCompilerParams, + true, + ) + else + Enzyme.Compiler.GLOBAL_FWD_CACHE + end + + Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) +end + +@inline primal_interp_world( + @nospecialize(::ReverseModeSplit), + world::UInt) = primal_interp_world(Reverse, world) + +function primal_return_type_world( + @nospecialize(mode::Mode), + world::UInt, + @nospecialize(TT::Type), +) + Core.Compiler._return_type(primal_interp_world(mode, world), TT) +end + +function primal_return_type_world( + @nospecialize(mode::Mode), + world::UInt, + mi::Core.MethodInstance, +) + interp = primal_interp_world(mode, world) + return_type(interp, mi) +end + +primal_return_type_world( + @nospecialize(mode::Mode), + world::UInt, + @nospecialize(FT::Type), + @nospecialize(TT::Type), + ) = primal_return_type_world(mode, world, Tuple{FT, TT.parameters...}) + +function primal_return_type_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type)) + @nospecialize + @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) + @assert mode <: Mode + mode = mode() + ft = ft.parameters[1] + tt = tt.parameters[1] + + # validation + ft <: Core.Builtin && + error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") + + # look up the method + method_error = :(throw(MethodError(ft, tt, $world))) + sig = Tuple{ft,tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = 
Ref{UInt}(typemax(UInt)) + has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results + #interp = primal_interp_world(mode, world) + #method_table = Core.Compiler.method_table(interp) + method_table = nothing + mthds = Base._methods_by_ftype( + sig, + method_table, + -1, #=lim=# + world, + false, #=ambig=# + min_world, + max_world, + has_ambig, + ) + stub = Core.GeneratedFunctionStub( + identity, + Core.svec(:methodinstance, :mode, :ft, :tt), + Core.svec(), + ) + mthds === nothing && return stub(world, source, method_error) + length(mthds) == 1 || return stub(world, source, method_error) + + # look up the method and code instance + mtypes, msp, m = mthds[1] + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + m, + mtypes, + msp, + ) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + new_ci.edges = Core.MethodInstance[mi] + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. 
+ + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :mode, :ft, :tt] + new_ci.slotflags = UInt8[0x00 for i = 1:4] + + # return the codegen world age + res = primal_return_type_world(mode, world, mi) + push!(new_ci.code, Core.Compiler.ReturnNode(res)) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below + end + new_ci.ssavaluetypes += 1 + + # NOTE: we keep the first entry of the original linetable, and use it for location info + # on the call to check_cache. we can't not have a codeloc (using 0 causes + # corruption of the back trace), and reusing the target function's info + # has as advantage that we see the name of the kernel in the backtraces. + + return new_ci +end + +@eval Base.@assume_effects :removable :foldable :nothrow @inline function primal_return_type(mode::Mode, ft::Type, tt::Type) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, primal_return_type_generator)) +end + diff --git a/src/typeutils/jltypes.jl b/src/typeutils/jltypes.jl new file mode 100644 index 0000000000..4c1b322eb9 --- /dev/null +++ b/src/typeutils/jltypes.jl @@ -0,0 +1,297 @@ + +struct RemovedParam end + +# Modified from GPUCompiler classify_arguments +function classify_arguments( + @nospecialize(source_sig::Type), + codegen_ft::LLVM.FunctionType, + has_sret::Bool, + has_returnroots::Bool, + has_swiftself::Bool, + parmsRemoved::Vector{UInt64}, +) + codegen_types = parameters(codegen_ft) + + args = [] + codegen_i = 1 + orig_i = 1 + if has_sret + if !in(orig_i - 1, parmsRemoved) + codegen_i += 1 + end + orig_i += 1 + end + if has_returnroots + if !in(orig_i - 1, parmsRemoved) + codegen_i += 1 + end + orig_i += 1 + end + if has_swiftself + if !in(orig_i - 1, parmsRemoved) + codegen_i += 1 + end + orig_i += 1 + end + for (source_i, source_typ) in enumerate(source_sig.parameters) + if 
isghostty(source_typ) || Core.Compiler.isconstType(source_typ) + push!(args, (cc = GPUCompiler.GHOST, typ = source_typ, arg_i = source_i)) + continue + end + if in(orig_i - 1, parmsRemoved) + push!(args, (cc = RemovedParam, typ = source_typ)) + orig_i += 1 + continue + end + codegen_typ = codegen_types[codegen_i] + + if codegen_typ isa LLVM.PointerType + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) + # pointers are used for multiple kinds of arguments + # - literal pointer values + if source_typ <: Ptr || source_typ <: Core.LLVMPtr + @assert llvm_source_typ == codegen_typ + push!( + args, + ( + cc = GPUCompiler.BITS_VALUE, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + # - boxed values + # XXX: use `deserves_retbox` instead? + elseif llvm_source_typ isa LLVM.PointerType + @assert llvm_source_typ == codegen_typ + push!( + args, + ( + cc = GPUCompiler.MUT_REF, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + # - references to aggregates + else + @assert llvm_source_typ != codegen_typ + push!( + args, + ( + cc = GPUCompiler.BITS_REF, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + end + else + push!( + args, + ( + cc = GPUCompiler.BITS_VALUE, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + end + + codegen_i += 1 + orig_i += 1 + end + + return args +end + +# https://github.com/JuliaLang/julia/blob/64378db18b512677fc6d3b012e6d1f02077af191/src/cgutils.cpp#L823 +# returns if all unboxed +function for_each_uniontype_small(@nospecialize(f), @nospecialize(ty::Type), counter::Base.RefValue{Int} = Ref(0)) + if counter[] > 127 + return false + end + if ty isa Union + allunbox = for_each_uniontype_small(f, ty.a, counter) + allunbox &= for_each_uniontype_small(f, ty.b, counter) + return allunbox + end + # 
https://github.com/JuliaLang/julia/blob/170d6439445c86e640214620dad3423d2bb42337/src/codegen.cpp#L1233 + if Base.isconcretetype(ty) && !ismutabletype(ty) && Base.datatype_pointerfree(ty) + counter[] += 1 + f(ty) + return true + end + return false +end + +# From https://github.com/JuliaLang/julia/blob/038d31463f0ef744c8308bdbe87339b9c3f0b890/src/cgutils.cpp#L3108 +function union_alloca_type(@nospecialize(UT::Type)) + nbytes = 0 + function inner(@nospecialize(jlrettype::Type)) + if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) + nbytes = max(nbytes, sizeof(jlrettype)) + end + end + for_each_uniontype_small(inner, UT) + return nbytes +end + +# From https://github.com/JuliaLang/julia/blob/e6bf81f39a202eedc7bd4f310c1ab60b5b86c251/src/codegen.cpp#L6447 +function is_sret(@nospecialize(jlrettype::Type)) + if jlrettype === Union{} + # jlrettype == (jl_value_t*)jl_bottom_type + return false + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) + # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) + return false + elseif jlrettype isa Union # jl_is_uniontype(jlrettype) + if union_alloca_type(jlrettype) > 0 + # sret, also a regular return here + return true + end + return false + elseif !GPUCompiler.deserves_retbox(jlrettype) + rt = convert(LLVMType, jlrettype) + if !isa(rt, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, rt) + return true + end + end + return false +end +function is_sret_union(@nospecialize(jlrettype::Type)) + if jlrettype === Union{} + # jlrettype == (jl_value_t*)jl_bottom_type + return false + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) + # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) + return false + elseif jlrettype isa Union # jl_is_uniontype(jlrettype) + if union_alloca_type(jlrettype) > 0 + # sret, also a regular return here + return true + end + end 
+ return false +end + +# https://github.com/JuliaLang/julia/blob/0a696a3842750fcedca8832bc0aabe9096c7658f/src/codegen.cpp#L6812 +function get_return_info( + @nospecialize(jlrettype::Type), +)::Tuple{Union{Nothing,Type},Union{Nothing,Type},Union{Nothing,Type}} + sret = nothing + returnRoots = nothing + rt = nothing + if jlrettype === Union{} + rt = Nothing + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) + rt = Nothing + elseif jlrettype isa Union + nbytes = 0 + allunbox = for_each_uniontype_small(jlrettype) do jlrettype + if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) + nbytes = max(nbytes, sizeof(jlrettype)) + end + end + if nbytes != 0 + rt = NamedTuple{(Symbol("1"), Symbol("2")),Tuple{Any,UInt8}} + # Pointer to?, Ptr{NTuple{UInt8, allunbox} + sret = Ptr{jlrettype} + elseif allunbox + rt = UInt8 + else + rt = Any + end + elseif jlrettype <: Tuple && in(Any, jlrettype.parameters) + rt = Any + elseif !GPUCompiler.deserves_retbox(jlrettype) + lRT = convert(LLVMType, jlrettype) + if !isa(lRT, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, lRT) + sret = Ptr{jlrettype} + tracked = CountTrackedPointers(lRT) + @assert !tracked.derived + if tracked.count != 0 && !tracked.all + returnRoots = Ptr{AnyArray(Int(tracked.count))} + end + else + rt = jlrettype + end + else + # retbox + rt = Ptr{jlrettype} + end + + return (rt, sret, returnRoots) +end + +# From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 +@inline function isghostty(@nospecialize(ty)) + if ty === Union{} + return true + end + if Base.isconcretetype(ty) && !ismutabletype(ty) + if sizeof(ty) == 0 + return true + end + # TODO consider struct_to_llvm ? 
+ end + return false +end + +struct Tape{TapeTy,ShadowTy,ResT} + internal_tape::TapeTy + shadow_return::ShadowTy +end + + +@inline any_jltypes(::Type{Nothing}) = false +@inline any_jltypes(::Type{T}) where {T<:AbstractFloat} = false +@inline any_jltypes(::Type{T}) where {T<:Integer} = false +@inline any_jltypes(::Type{Complex{T}}) where {T} = any_jltypes(T) +@inline any_jltypes(::Type{Tuple{}}) = false +@inline any_jltypes(::Type{NTuple{Size,T}}) where {Size,T} = any_jltypes(T) +@inline any_jltypes(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = 10 <= Addr <= 12 +@inline any_jltypes(::Type{Any}) = true +@inline any_jltypes(::Type{NamedTuple{A,B}}) where {A,B} = + any(any_jltypes(b) for b in B.parameters) +@inline any_jltypes(::Type{T}) where {T<:Tuple} = any(any_jltypes(b) for b in T.parameters) + +const WideIntWidths = [256, 512, 1024, 2048] + +let + for n ∈ WideIntWidths + let T = Symbol(:UInt, n) + eval(quote + primitive type $T <: Unsigned $n end + end) + end + end +end + +function jl_set_typeof(v::Ptr{Cvoid}, @nospecialize(T::Type)) + tag = reinterpret(Ptr{Any}, reinterpret(UInt, v) - 8) + Base.unsafe_store!(tag, T) # set tag + return nothing +end + +@generated function splatnew(::Type{T}, args::TT) where {T,TT<:Tuple} + return quote + Base.@_inline_meta + $(Expr(:splatnew, :T, :args)) + end +end + +@inline remove_innerty(::Type{<:Const}) = Const +@inline remove_innerty(::Type{<:Active}) = Active +@inline remove_innerty(::Type{<:Duplicated}) = Duplicated +@inline remove_innerty(::Type{<:DuplicatedNoNeed}) = DuplicatedNoNeed +@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated +@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed +@inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated +@inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated diff --git a/src/typeutils/lltypes.jl b/src/typeutils/lltypes.jl new file mode 100644 index 0000000000..be6bb126c7 --- /dev/null +++ b/src/typeutils/lltypes.jl @@ 
-0,0 +1,200 @@ +function isSpecialPtr(@nospecialize(Ty::LLVM.LLVMType)) + if !isa(Ty, LLVM.PointerType) + return false + end + AS = LLVM.addrspace(Ty) + return 10 <= AS && AS <= 13 +end + +mutable struct CountTrackedPointers + count::UInt + all::Bool + derived::Bool +end + +function CountTrackedPointers(@nospecialize(T::LLVM.LLVMType)) + res = CountTrackedPointers(0, true, false) + + if isa(T, LLVM.PointerType) + if isSpecialPtr(T) + res.count += 1 + if LLVM.addrspace(T) != Tracked + res.derived = true + end + end + elseif isa(T, LLVM.StructType) + for ElT in elements(T) + sub = CountTrackedPointers(ElT) + res.count += sub.count + res.all &= sub.all + res.derived |= sub.derived + end + elseif isa(T, LLVM.ArrayType) + sub = CountTrackedPointers(eltype(T)) + res.count += sub.count + res.all &= sub.all + res.derived |= sub.derived + res.count *= length(T) + elseif isa(T, LLVM.VectorType) + sub = CountTrackedPointers(eltype(T)) + res.count += sub.count + res.all &= sub.all + res.derived |= sub.derived + res.count *= size(T) + end + if res.count == 0 + res.all = false + end + return res +end + +# must deserve sret +function deserves_rooting(@nospecialize(T::LLVM.LLVMType)) + tracked = CountTrackedPointers(T) + @assert !tracked.derived + if tracked.count != 0 && !tracked.all + return true # tracked.count; + end + return false +end + + +function any_jltypes(Type::LLVM.PointerType) + if 10 <= LLVM.addrspace(Type) <= 12 + return true + else + # do we care about {} addrspace(11)** + return false + end +end + +any_jltypes(Type::LLVM.StructType) = any(any_jltypes, LLVM.elements(Type)) +any_jltypes(Type::Union{LLVM.VectorType,LLVM.ArrayType}) = any_jltypes(eltype(Type)) +any_jltypes(::LLVM.IntegerType) = false +any_jltypes(::LLVM.FloatingPointType) = false +any_jltypes(::LLVM.VoidType) = false + +nfields(Type::LLVM.StructType) = length(LLVM.elements(Type)) +nfields(Type::LLVM.VectorType) = size(Type) +nfields(Type::LLVM.ArrayType) = length(Type) +nfields(Type::LLVM.PointerType) 
= 1 + +function store_nonjl_types!(B::LLVM.IRBuilder, @nospecialize(startval::LLVM.Value), @nospecialize(p::LLVM.Value)) + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + vals = LLVM.Value[] + if p != nothing + push!(vals, p) + end + todo = Tuple{Tuple,LLVM.Value}[((), startval)] + while length(todo) != 0 + path, cur = popfirst!(todo) + ty = value_type(cur) + if isa(ty, LLVM.PointerType) + if any_jltypes(ty) + continue + end + end + if isa(ty, LLVM.ArrayType) + if any_jltypes(ty) + for i = 1:length(ty) + ev = extract_value!(B, cur, i - 1) + push!(todo, ((path..., i - 1), ev)) + end + continue + end + end + if isa(ty, LLVM.StructType) + if any_jltypes(ty) + for (i, t) in enumerate(LLVM.elements(ty)) + ev = extract_value!(B, cur, i - 1) + push!(todo, ((path..., i - 1), ev)) + end + continue + end + end + parray = LLVM.Value[LLVM.ConstantInt(LLVM.IntType(64), 0)] + for v in path + push!(parray, LLVM.ConstantInt(LLVM.IntType(32), v)) + end + gptr = gep!(B, value_type(startval), p, parray) + st = store!(B, cur, gptr) + end + return +end + +function get_julia_inner_types(B::LLVM.IRBuilder, @nospecialize(p::Union{Nothing, LLVM.Value}), @nospecialize(startvals::Vararg{LLVM.Value}); added = LLVM.API.LLVMValueRef[]) + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + vals = LLVM.Value[] + if p != nothing + push!(vals, p) + end + todo = LLVM.Value[startvals...] 
+ while length(todo) != 0 + cur = popfirst!(todo) + ty = value_type(cur) + if isa(ty, LLVM.PointerType) + if any_jltypes(ty) + if addrspace(ty) != Tracked + cur = addrspacecast!( + B, + cur, + LLVM.PointerType(eltype(ty), Tracked), + LLVM.name(cur) * ".innertracked", + ) + if isa(cur, LLVM.Instruction) + push!(added, cur.ref) + end + end + if value_type(cur) != T_prjlvalue + cur = bitcast!(B, cur, T_prjlvalue) + if isa(cur, LLVM.Instruction) + push!(added, cur.ref) + end + end + push!(vals, cur) + end + continue + end + if isa(ty, LLVM.ArrayType) + if any_jltypes(ty) + for i = 1:length(ty) + ev = extract_value!(B, cur, i - 1) + if isa(ev, LLVM.Instruction) + push!(added, ev.ref) + end + push!(todo, ev) + end + end + continue + end + if isa(ty, LLVM.StructType) + for (i, t) in enumerate(LLVM.elements(ty)) + if any_jltypes(t) + ev = extract_value!(B, cur, i - 1) + if isa(ev, LLVM.Instruction) + push!(added, ev.ref) + end + push!(todo, ev) + end + end + continue + end + if isa(ty, LLVM.IntegerType) + continue + end + if isa(ty, LLVM.FloatingPointType) + continue + end + msg = sprint() do io + println(io, "Enzyme illegal subtype") + println(io, "ty=", ty) + println(io, "cur=", cur) + println(io, "p=", p) + println(io, "startvals=", startvals) + end + throw(AssertionError(msg)) + end + return vals +end diff --git a/src/utils.jl b/src/utils.jl index 8fc1ce4962..0f92fc1f5d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -175,7 +175,7 @@ if VERSION >= v"1.11.0-DEV.1552" const prevmethodinstance = GPUCompiler.generic_methodinstance -function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type) +function methodinstance_generator(world::UInt, source, self, @nospecialize(ft::Type), @nospecialize(tt::Type)) @nospecialize @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) ft = ft.parameters[1] @@ -225,19 +225,19 @@ function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type) return new_ci end -@eval function prevmethodinstance(ft, 
tt) +@eval function prevmethodinstance(ft, tt)::Core.MethodInstance $(Expr(:meta, :generated_only)) $(Expr(:meta, :generated, methodinstance_generator)) end # XXX: version of Base.method_instance that uses a function type @inline function my_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type), - world::Integer=tls_world_age()) + world::Integer=tls_world_age())::Core.MethodInstance sig = GPUCompiler.signature_type_by_tt(ft, tt) if Base.isdispatchtuple(sig) # JuliaLang/julia#52233 - return GPUCompiler.methodinstance(ft, tt, world) + return GPUCompiler.methodinstance(ft, tt, world)::Core.MethodInstance else - return prevmethodinstance(ft, tt, world) + return prevmethodinstance(ft, tt, world)::Core.MethodInstance end end else From 79678f7a93fd1a65ffc633ed132baf8d33a1b4f8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Dec 2024 16:14:03 -0500 Subject: [PATCH 467/495] Add no recur to inference (#2153) --- src/compiler/interpreter.jl | 56 ++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 1e442482be..7b48acf271 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -806,13 +806,13 @@ function abstract_call_known( [:(Enzyme.Compiler.Interpreter.override_bc_materialize), fargs[2:end]...], [Core.Const(Enzyme.Compiler.Interpreter.override_bc_materialize), argtypes[2:end]...], ) - return abstract_call_known( - interp, - Enzyme.Compiler.Interpreter.override_bc_materialize, - arginfo2, - si, - sv, - max_methods, + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, + Enzyme.Compiler.Interpreter.override_bc_materialize::Any, + arginfo2::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int, ) end end @@ -830,13 +830,13 @@ function abstract_call_known( [:(Enzyme.Compiler.Interpreter.override_bc_copyto!), fargs[2:end]...], [Core.Const(Enzyme.Compiler.Interpreter.override_bc_copyto!), 
argtypes[2:end]...], ) - return abstract_call_known( - interp, - Enzyme.Compiler.Interpreter.override_bc_copyto!, - arginfo2, - si, - sv, - max_methods, + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, + Enzyme.Compiler.Interpreter.override_bc_copyto!::Any, + arginfo2::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int, ) end end @@ -854,13 +854,13 @@ function abstract_call_known( [:(Enzyme.Compiler.Interpreter.myunsafe_copyto!), fargs[2:end]...], [Core.Const(Enzyme.Compiler.Interpreter.myunsafe_copyto!), argtypes[2:end]...], ) - return abstract_call_known( - interp, - Enzyme.Compiler.Interpreter.myunsafe_copyto!, - arginfo2, - si, - sv, - max_methods, + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, + Enzyme.Compiler.Interpreter.myunsafe_copyto!::Any, + arginfo2::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int, ) end end @@ -874,13 +874,13 @@ function abstract_call_known( [:(Enzyme.autodiff_deferred), fargs[2:end]...], [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], ) - return abstract_call_known( - interp, - Enzyme.autodiff_deferred, - arginfo2, - si, - sv, - max_methods, + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, + Enzyme.autodiff_deferred::Any, + arginfo2::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int, ) end end From 31055b5de3ea86968a35d99bda987dcd9caf978f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Dec 2024 17:36:19 -0500 Subject: [PATCH 468/495] Further reduce recursive inference (#2154) * Further reduce recursive inference * fix --- src/compiler/interpreter.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 7b48acf271..ff4c6c991b 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -797,10 +797,6 @@ function abstract_call_known( if f === Base.materialize && length(argtypes) == 2 bcty = widenconst(argtypes[2]) if 
Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && all(array_or_number, bcty.parameters[4].parameters) && any(Base.Fix2(Base.:<:, AbstractArray), bcty.parameters[4].parameters) - fnty = bcty.parameters[3] - eltys = map(num_or_eltype, bcty.parameters[4].parameters) - retty = Core.Compiler._return_type(interp, Tuple{fnty, eltys...}) - if Base.isconcretetype(retty) arginfo2 = ArgInfo( fargs isa Nothing ? nothing : [:(Enzyme.Compiler.Interpreter.override_bc_materialize), fargs[2:end]...], @@ -814,7 +810,6 @@ function abstract_call_known( sv::AbsIntState, max_methods::Int, ) - end end end From 611dda9a55741c4db95a806f820f8eb4c62d7030 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Dec 2024 17:36:44 -0500 Subject: [PATCH 469/495] vararg type (#2150) * vararg type * Update compiler.jl --- src/compiler.jl | 64 +++++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0db62b6d48..5eccd1e80d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -112,29 +112,32 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data Tys = (Float32, Float64) if length(sparam_vals) == arity - T = first(sparam_vals)::Type - legal = T ∈ Tys - - if legal - if name == :ldexp - if !(sparam_vals[2] <: Integer) - legal = false - end - elseif name == :pow - if sparam_vals[2] <: Integer - name = :powi - elseif sparam_vals[2] != T - legal = false - end - elseif name == :jl_rem2pi - else - if !all(==(T), sparam_vals) - legal = false + T = first(sparam_vals) + if (T isa Type) + T = T::Type + legal = T ∈ Tys + + if legal + if name == :ldexp + if !(sparam_vals[2] <: Integer) + legal = false + end + elseif name == :pow + if sparam_vals[2] <: Integer + name = :powi + elseif sparam_vals[2] != T + legal = false + end + elseif name == :jl_rem2pi + else + if !all(==(T), sparam_vals) + legal = false + end end end - end - if 
legal - return name, toinject, T + if legal + return name, toinject, T + end end end end @@ -144,15 +147,18 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data Tys = (Complex{Float32}, Complex{Float64}) if length(sparam_vals) == arity T = first(sparam_vals) - legal = T ∈ Tys - - if legal - if !all(==(T), sparam_vals) - legal = false + if (T isa Type) + T = T::Type + legal = T ∈ Tys + + if legal + if !all(==(T), sparam_vals) + legal = false + end + end + if legal + return name, toinject, T end - end - if legal - return name, toinject, T end end end From ab705ac56bb18ac6bd2fd7bcdbd33e6c7505a803 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Dec 2024 17:37:07 -0500 Subject: [PATCH 470/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b5d14f19ad..0a8ef9e338 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.17" +version = "0.13.18" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 4c6e7ccc778c459dec3c5f2e56a4b530dabdfcab Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Dec 2024 20:04:56 -0500 Subject: [PATCH 471/495] Absint through sret unions (#2155) * Absint through sret unions * Update Project.toml --- src/absint.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/absint.jl b/src/absint.jl index 169db8965f..3b9034bef6 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -652,6 +652,13 @@ function abs_typeof( break end cnt += 1 + if Enzyme.Compiler.is_sret_union(styp) + if cnt == ind + typ = UInt8 + break + end + cnt += 1 + end end end if Base.allocatedinline(typ) From 3ad827f69299299b92a1448f52dd746a65eb5db7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 2 Dec 2024 00:10:06 -0500 Subject: [PATCH 472/495] Don't use ref when unnnecessary for gradient sugar (#2156) 
* Don't use ref when unnnecessary for gradient sugar * fix * fix * fix * fix --- src/sugar.jl | 108 ++++++++++++++++++++++++++++----------------------- 1 file changed, 59 insertions(+), 49 deletions(-) diff --git a/src/sugar.jl b/src/sugar.jl index 3e68830100..b93b7fb0eb 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -254,19 +254,15 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) ``` """ +# TODO eventually add an invalidation edge here from inactive_type @generated function gradient( rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{Any,N}, ) where {F,ty_0,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten,N} - toemit = Expr[quote - act_0 = - !(x isa Enzyme.Const) && - Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == - Compiler.ActiveState #=justActive=# - end] rargs = Union{Symbol,Expr}[:x] + gentys = Type[x] acts = Symbol[Symbol("act_0")] for i = 1:N @@ -276,55 +272,69 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) push!(rargs, argidx) sym = Symbol("act_$i") push!(acts, sym) - push!( - toemit, - quote - $sym = - !($argidx isa Enzyme.Const) && - Compiler.active_reg_inner( - Core.Typeof($argidx), - (), - nothing, - Val(true), - ) == Compiler.ActiveState #=justActive=# - end, - ) + push!(gentys, args[i]) + end + + toemit = Expr[] + states = Compiler.ActivityState[] + + for (argidx, act, genty) in zip(rargs, acts, gentys) + if genty <: Enzyme.Const + push!( + toemit, + quote + $act = false + end + ) + push!(states, Compiler.AnyState) + else + state = Compiler.active_reg_inner(genty, (), nothing) + push!(states, state) + end end idx = 0 - shadows = Symbol[] - enz_args = Expr[] - resargs = Expr[] - for (arg, act) in zip(rargs, acts) - shad = Symbol("shad_$idx") - push!(shadows, shad) - push!(toemit, quote - $shad = if $arg isa Enzyme.Const - nothing - elseif $act - Ref(make_zero($arg)) - else - make_zero($arg) - end - end) - push!(enz_args, quote - 
if $arg isa Enzyme.Const - $arg - elseif $act + enz_args = Union{Expr,Symbol}[] + resargs = Union{Expr,Symbol}[] + for (i, (arg, act, state, genty)) in enumerate(zip(rargs, acts, states, gentys)) + shad = Symbol("shad_$i") + if genty <: Enzyme.Const + push!(enz_args, arg) + push!(resargs, :nothing) + elseif state == Compiler.MixedState + push!(toemit, quote + $shad = Ref(make_zero($arg)) + end) + push!(enz_args, quote MixedDuplicated($arg, $shad) - else - Duplicated($arg, $shad) - end - end) - push!(resargs, quote - if $arg isa Enzyme.Const - nothing - elseif $act + end) + push!(resargs, quote $shad[] - else + end) + elseif state == Compiler.DupState + push!(toemit, quote + $shad = make_zero($arg) + end) + push!(enz_args, quote + Duplicated($arg, $shad) + end) + push!(resargs, quote $shad - end - end) + end) + elseif state == Compiler.ActiveState + push!(enz_args, quote + Active($arg) + end) + push!(resargs, quote + res[1][$i] + end) + else + @assert state == Compiler.AnyState + push!(enz_args, quote + Const($arg) + end) + push!(resargs, :nothing) + end idx += 1 end push!(toemit, quote From 358d6475ef12965c4052b34bdc3f4183e5471f44 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Dec 2024 19:33:48 -0600 Subject: [PATCH 473/495] Precompilation is cool, we should do more of it (#2160) * Precompilation is cool, we should do more of it * fix * tm stuff * ix attempt * reset * more * ix * reduce * fix --- Project.toml | 1 + src/Enzyme.jl | 2 ++ src/compiler/orcv2.jl | 18 ++++++++++++------ src/precompile.jl | 13 +++++++++++++ 4 files changed, 28 insertions(+), 6 deletions(-) create mode 100644 src/precompile.jl diff --git a/Project.toml b/Project.toml index 0a8ef9e338..7a2040f9ef 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92" +PrecompileTools = 
"aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 942df0581c..92ec9a623f 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1587,4 +1587,6 @@ Returns true if within autodiff, otherwise false. """ @inline EnzymeCore.within_autodiff() = false +include("precompile.jl") + end # module diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 7588eddb78..1640b05db2 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -83,7 +83,7 @@ function define_absolute_symbol(jd, name) return false end -function __init__() +function setup_globals() opt_level = Base.JLOptions().opt_level if opt_level < 2 optlevel = LLVM.API.LLVMCodeGenLevelNone @@ -105,11 +105,6 @@ function __init__() dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix) LLVM.add!(jd_main, dg) - if Sys.iswindows() && Int === Int64 - # TODO can we check isGNU? - define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms")) - end - es = ExecutionSession(lljit) try lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es) @@ -120,6 +115,17 @@ function __init__() jit[] = CompilerInstance(lljit, nothing, nothing) end + jd_main, lljit +end + +function __init__() + jd_main, lljit = setup_globals() + + if Sys.iswindows() && Int === Int64 + # TODO can we check isGNU? 
+ define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms")) + end + hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) for (k, v) in Compiler.JuliaGlobalNameMap ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) diff --git a/src/precompile.jl b/src/precompile.jl new file mode 100644 index 0000000000..c20eaac149 --- /dev/null +++ b/src/precompile.jl @@ -0,0 +1,13 @@ +using PrecompileTools: @setup_workload, @compile_workload + +@setup_workload begin + precompile_module = @eval module $(gensym()) + f(x) = x^2 + end + + Compiler.JIT.setup_globals() + + @compile_workload begin + Enzyme.autodiff(Reverse, precompile_module.f, Active(2.0)) + end +end From b046156330b368373cac4b349716695aa2940af7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Dec 2024 21:23:23 -0600 Subject: [PATCH 474/495] Only mark writeonly if pointer abi (#2163) * Only mark writeonly if pointer abi * Update attributes.jl --- src/llvm/attributes.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/llvm/attributes.jl b/src/llvm/attributes.jl index 3dd5973421..efa090e7e6 100644 --- a/src/llvm/attributes.jl +++ b/src/llvm/attributes.jl @@ -831,8 +831,10 @@ function annotate!(mod::LLVM.Module) for fn in funcs[fname] push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) - push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) - push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) + if value_type(LLVM.parameters(fn)[4]) isa LLVM.PointerType + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) + push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) + end if LLVM.version().major <= 15 push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) else From 6e513f2d25d3ace7aa969035ff9fbe71d22ae5a3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 4 Dec 2024 09:52:41 -0600 
Subject: [PATCH 475/495] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 7a2040f9ef..fb83e4263a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.18" +version = "0.13.19" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -37,7 +37,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.7" -Enzyme_jll = "0.0.166" +Enzyme_jll = "0.0.167" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From ae1634498d375febb31bd06944d57a50a6b39c7e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 19:24:10 -0600 Subject: [PATCH 476/495] CompatHelper: add new compat entry for PrecompileTools at version 1, (keep existing compat) (#2167) Co-authored-by: CompatHelper Julia --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index fb83e4263a..d86475034a 100644 --- a/Project.toml +++ b/Project.toml @@ -42,6 +42,7 @@ GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" +PrecompileTools = "1" Preferences = "1.4" SparseArrays = "1" SpecialFunctions = "1, 2" From 5b8586229054872653cd56addecef62bc8cab82d Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 4 Dec 2024 19:25:47 -0600 Subject: [PATCH 477/495] Mark extract value types (#2166) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5eccd1e80d..78a266c78e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4148,7 +4148,7 @@ end if !API.HasFromStack(inst) && ((isa(inst, LLVM.CallInst) && - (!isa(fn, LLVM.Function) || isempty(blocks(fn))) 
) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst)) + (!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst) || isa(inst, LLVM.ExtractValueInst)) legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal From 6606cd96184364cb7d39ae0e75259be5066b5630 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 5 Dec 2024 05:03:25 +0100 Subject: [PATCH 478/495] vc/fixup isapplicable use v2 (#2158) * Forward interp and sv to isapplicable * invalidation for inactive now works * move tfunc to compilers * fixup! move tfunc to compilers * fixup! fixup! move tfunc to compilers * add ephermal cache * bump versions --------- Co-authored-by: William S. Moses --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/rules.jl | 22 ++------------- src/compiler/interpreter.jl | 39 +++++++++++++++----------- src/compiler/tfunc.jl | 56 +++++++++++++++++++++++++++++++++++++ test/ruleinvalidation.jl | 8 ++---- 6 files changed, 85 insertions(+), 44 deletions(-) create mode 100644 src/compiler/tfunc.jl diff --git a/Project.toml b/Project.toml index d86475034a..46b37ccb12 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.7" +EnzymeCore = "0.8.8" Enzyme_jll = "0.0.167" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 28f92d9055..da662c545a 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.7" +version = "0.8.8" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 945951b216..dc33c11110 100644 --- 
a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -171,7 +171,7 @@ end function has_frule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) @@ -180,7 +180,7 @@ end function has_rrule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -211,12 +211,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); if !fullmatch if caller isa Core.MethodInstance add_mt_backedge!(caller, mt, sig) - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - add_mt_backedge!(cspec, mt, sig) - end end end if Core.Compiler.isempty(matches) @@ -228,16 +222,6 @@ function 
isapplicable(@nospecialize(f), @nospecialize(TT); edge = Core.Compiler.specialize_method(match)::Core.MethodInstance add_backedge!(caller, edge, sig) end - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(cspec, edge, sig) - end - end end return true end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index ff4c6c991b..2d02604eda 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -40,6 +40,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams + rules_cache::IdDict{Any, Bool} + forward_rules::Bool reverse_rules::Bool deferred_lower::Bool @@ -78,6 +80,7 @@ function EnzymeInterpreter( # parameters for inference and optimization parms, OptimizationParams(), + IdDict{Any, Bool}(), forward_rules, reverse_rules, deferred_lower, @@ -168,6 +171,8 @@ function simplify_kw(@nospecialize(specTypes)) end end +include("tfunc.jl") + import Core.Compiler: CallInfo struct NoInlineCallInfo <: CallInfo info::CallInfo # wrapped call @@ -192,6 +197,7 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +import .EnzymeRules: FwdConfig, RevConfig, Annotation using Core.Compiler: ArgInfo, StmtInfo, AbsIntState function Core.Compiler.abstract_call_gf_by_type( @nospecialize(interp::EnzymeInterpreter), @@ -212,31 +218,30 @@ function Core.Compiler.abstract_call_gf_by_type( max_methods::Int, ) callinfo = ret.info - method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) - 
caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult - callinfo.results - else - nothing - end if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else - if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :frule) - end - end - - if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + # 1. Check if function is inactive + if is_inactive_from_sig(interp, specTypes, sv) + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + else + # 2. Check if rule is defined + has_rule = get!(interp.rules_cache, specTypes) do + if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) + return true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) + return true + else + return false + end end + if has_rule + callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? 
:frule : :rrule) + end end end @static if VERSION ≥ v"1.11-" diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl new file mode 100644 index 0000000000..701cfa8107 --- /dev/null +++ b/src/compiler/tfunc.jl @@ -0,0 +1,56 @@ +import EnzymeCore: Annotation +import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt + +function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + ft, tt = _annotate_tt(TT) + TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} + return isapplicable(interp, forward, TT, sv) +end + +function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + ft, tt = _annotate_tt(TT) + TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} + return isapplicable(interp, augmented_primal, TT, sv) +end + + +function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState) + return isapplicable(interp, inactive, TT, sv) +end + +# `hasmethod` is a precise match using `Core.Compiler.findsup`, +# but here we want the broader query using `Core.Compiler.findall`. +# Also add appropriate backedges to the caller `MethodInstance` if given. 
+function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + tt = Base.to_tuple_type(TT) + sig = Base.signature_type(f, tt) + mt = ccall(:jl_method_table_for, Any, (Any,), sig) + mt isa Core.MethodTable || return false + result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp); limit=-1) + (result === nothing || result === missing) && return false + @static if isdefined(Core.Compiler, :MethodMatchResult) + (; matches) = result + else + matches = result + end + # also need an edge to the method table in case something gets + # added that did not intersect with any existing method + fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) + if !fullmatch + Core.Compiler.add_mt_backedge!(sv, mt, sig) + end + if Core.Compiler.isempty(matches) + return false + else + for i = 1:Core.Compiler.length(matches) + match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + Core.Compiler.add_backedge!(sv, edge) + end + return true + end +end \ No newline at end of file diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 704ada2b6e..37cb21b08f 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -34,18 +34,14 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 @static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 else -@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 end # now test invalidation for `inactive` inactive(::typeof(issue696), args...) 
= nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -else @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -end end # module From f8bf821916e89c37d7c592d1291b200c003906bb Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 17:24:33 -0600 Subject: [PATCH 479/495] Fix partial store (#2172) * Fix partial store * replace uses --- src/llvm/transforms.jl | 32 +++++++++++++++++ test/passes.jl | 80 ++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 113 insertions(+) create mode 100644 test/passes.jl diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index b1a5aaafbc..aebb8bab5c 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -1539,6 +1539,13 @@ function propagate_returned!(mod::LLVM.Module) end else for u in LLVM.uses(un) + u = LLVM.user(u) + if u isa LLVM.CallInst + op = LLVM.called_operand(u) + if op isa LLVM.Function && LLVM.name(op) == "llvm.enzymefakeread" + continue + end + end hasAnyUse = true break end @@ -1611,6 +1618,12 @@ end function delete_writes_into_removed_args(fn::LLVM.Function, toremove::Vector{Int64}, keepret::Bool) args = collect(parameters(fn)) + if !keepret + for u in LLVM.uses(fn) + u = LLVM.user(u) + replace_uses!(u, LLVM.UndefValue(value_type(u))) + end + end for tr in toremove tr = tr + 1 todorep = Tuple{LLVM.Instruction, LLVM.Value}[] @@ -2038,6 +2051,25 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine) if isempty(blocks(fn)) continue end + + rt = LLVM.return_type(LLVM.function_type(fn)) + if rt isa LLVM.PointerType && addrspace(rt) == 10 + for u in LLVM.uses(fn) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + B = IRBuilder() + nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) + position!(B, nextInst) + cl = call!(B, funcT, rfunc, LLVM.Value[u]) + LLVM.API.LLVMAddCallSiteAttribute( + cl, 
+ LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), + ) + end + end + end + # Ensure that interprocedural optimizations do not delete the use of returnRoots (or shadows) # if inactive sret, this will only occur on 2. If active sret, inactive retRoot, can on 3, and # active both can occur on 4. If the original sret is removed (at index 1) we no longer need diff --git a/test/passes.jl b/test/passes.jl new file mode 100644 index 0000000000..d85ae822fb --- /dev/null +++ b/test/passes.jl @@ -0,0 +1,80 @@ +using Enzyme, LLVM, Test + + +@testset "Partial return preservation" begin + LLVM.Context() do ctx + mod = parse(LLVM.Module, """ + source_filename = "start" + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13" + target triple = "x86_64-linux-gnu" + + declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5 + + define internal fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2) { + top: + %newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30 + %a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)* + %a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1 + store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8 + %a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)* + %a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8 + %a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)* + store atomic {} addrspace(10)* %v2, {} addrspace(10)* addrspace(11)* %a35 release, align 8 + ret {} addrspace(10)* %newstruct + } + + define {} addrspace(10)* @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) { + top: + %ac = call 
fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2) + %b = addrspacecast {} addrspace(10)* %ac to {} addrspace(10)* addrspace(11)* + %c = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %b unordered, align 8 + ret {} addrspace(10)* %c + } + + attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" } + """) + + Enzyme.Compiler.removeDeadArgs!(mod, Enzyme.Compiler.JIT.get_tm()) + + callfn = LLVM.functions(mod)["inner"] + @test length(collect(filter(Base.Fix2(isa, LLVM.StoreInst), collect(instructions(first(blocks(callfn))))))) == 2 + end +end + + +@testset "Dead return removal" begin + LLVM.Context() do ctx + mod = parse(LLVM.Module, """ + source_filename = "start" + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13" + target triple = "x86_64-linux-gnu" + + declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5 + + define internal fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2) { + top: + %newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30 + %a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)* + %a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1 + store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8 + %a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)* + %a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8 + %a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)* + store atomic {} addrspace(10)* %v2, {} 
addrspace(10)* addrspace(11)* %a35 release, align 8 + ret {} addrspace(10)* %newstruct + } + + define void @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) { + top: + %ac = call fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2) + ret void + } + + attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" } + """) + + Enzyme.Compiler.removeDeadArgs!(mod, Enzyme.Compiler.JIT.get_tm()) + callfn = LLVM.functions(mod)["caller"] + @test length(collect(instructions(first(blocks(callfn))))) == 1 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 549978b894..fd31a13ea0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,6 +73,7 @@ end include("abi.jl") include("typetree.jl") +include("passes.jl") include("optimize.jl") include("make_zero.jl") From 2a447f471a4b4adfd6198fc39f75e6b7fb390493 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 17:24:45 -0600 Subject: [PATCH 480/495] Add verbose error message toggle (#2173) --- src/errors.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index b48c34b54d..2099f6ed64 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -25,15 +25,21 @@ struct IllegalTypeAnalysisException <: CompilationException bt::Union{Nothing,Vector{StackTraces.StackFrame}} end +const VERBOSE_ERRORS = Ref(false) function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) print(io, "Enzyme compilation failed due to illegal type analysis.\n") - if ece.ir !== nothing - print(io, "Current scope: \n") - print(io, ece.ir) + print(io, " This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].\n") + print(io, " Ideally, remove the union (which will also make your code faster), or try setting 
Enzyme.API.strictAliasing!(false) before any autodiff call.\n") + print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n") + if VERBOSE_ERRORS[] + if ece.ir !== nothing + print(io, "Current scope: \n") + print(io, ece.ir) + end + print(io, "\n Type analysis state: \n") + write(io, ece.sval) + print(io, '\n', ece.msg, '\n') end - print(io, "\n Type analysis state: \n") - write(io, ece.sval) - print(io, '\n', ece.msg, '\n') if ece.bt !== nothing print(io, "\nCaused by:") Base.show_backtrace(io, ece.bt) From 0dc217beb221c3ba0f2fb9f90d5ff1b7864a4e0c Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 19:55:59 -0600 Subject: [PATCH 481/495] Fix returns_twice attr (#2175) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 78a266c78e..2f4cbfc8da 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3338,7 +3338,7 @@ function GPUCompiler.codegen( if !has_fn_attr(f, EnumAttribute("alwaysinline")) continue end - if !has_fn_attr(f, EnumAttribute("returnstwice")) + if !has_fn_attr(f, EnumAttribute("returns_twice")) push!(function_attributes(f), EnumAttribute("returns_twice")) push!(toremove, name(f)) end From e8f3a89801f1902973c93e2eb690de8f8fdb35d2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 00:09:39 -0600 Subject: [PATCH 482/495] Further simplify error messages (#2178) * Further simplify error messages * even better --- src/errors.jl | 66 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index 2099f6ed64..4f7b6f9480 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -1,4 +1,17 @@ +const VERBOSE_ERRORS = Ref(false) + abstract type CompilationException <: Base.Exception end + +struct EnzymeRuntimeException <: Base.Exception + msg::Cstring +end + +function Base.showerror(io::IO, 
ece::EnzymeRuntimeException) + print(io, "Enzyme execution failed.\n") + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + struct NoDerivativeException <: CompilationException msg::String ir::Union{Nothing,String} @@ -8,10 +21,22 @@ end function Base.showerror(io::IO, ece::NoDerivativeException) print(io, "Enzyme compilation failed.\n") if ece.ir !== nothing - print(io, "Current scope: \n") - print(io, ece.ir) + if VERBOSE_ERRORS[] + print(io, "Current scope: \n") + print(io, ece.ir) + else + print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n") + end + end + if occursin("cannot handle unknown binary operator", ece.msg) + for msg in ece.msg.split('\n') + if occursin("cannot handle unknown binary operator", msg) + print('\n', msg, '\n') + end + end + else + print(io, '\n', ece.msg, '\n') end - print(io, '\n', ece.msg, '\n') if ece.bt !== nothing Base.show_backtrace(io, ece.bt) println(io) @@ -25,7 +50,6 @@ struct IllegalTypeAnalysisException <: CompilationException bt::Union{Nothing,Vector{StackTraces.StackFrame}} end -const VERBOSE_ERRORS = Ref(false) function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) print(io, "Enzyme compilation failed due to illegal type analysis.\n") print(io, " This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].\n") @@ -54,10 +78,14 @@ struct IllegalFirstPointerException <: CompilationException end function Base.showerror(io::IO, ece::IllegalFirstPointerException) - print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing + print(io, "Enzyme compilation failed due to an internal error (first pointer exception).\n") + print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl") + print(io, " To toggle more information for debugging (needed for bug reports), set 
Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n") + if VERBOSE_ERRORS[] + if ece.ir !== nothing print(io, "Current scope: \n") print(io, ece.ir) + end end print(io, '\n', ece.msg, '\n') if ece.bt !== nothing @@ -73,28 +101,28 @@ struct EnzymeInternalError <: CompilationException end function Base.showerror(io::IO, ece::EnzymeInternalError) - print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing + print(io, "Enzyme compilation failed due to an internal error.\n") + print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl") + print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n") + if VERBOSE_ERRORS[] + if ece.ir !== nothing print(io, "Current scope: \n") print(io, ece.ir) + end + print(io, '\n', ece.msg, '\n') + else + for msg in ece.msg.split('\n') + if occursin("Illegal replace ficticious phi for", msg) + print('\n', msg, '\n') + end + end end - print(io, '\n', ece.msg, '\n') if ece.bt !== nothing Base.show_backtrace(io, ece.bt) println(io) end end -struct EnzymeRuntimeException <: Base.Exception - msg::Cstring -end - -function Base.showerror(io::IO, ece::EnzymeRuntimeException) - print(io, "Enzyme execution failed.\n") - msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') -end - struct EnzymeMutabilityException <: Base.Exception msg::Cstring end From 81783453cf5b886bc4396ff9205c3d0c0868271c Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 08:17:05 -0600 Subject: [PATCH 483/495] More x86 orcv2 (#2177) * More x86 orcv2 * Update orcv2.jl * Update orcv2.jl * Update validation.jl * Update orcv2.jl * Update validation.jl --- src/compiler/validation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index e90f7d0712..e109415d0f 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -753,7 
+753,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, - reinterpret(Ptr{Cvoid}, JIT.lookup(hnd).ptr), + pointer(JIT.lookup(hnd)), ) else res = ccall( @@ -762,7 +762,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, - reinterpret(Ptr{Cvoid}, JIT.lookup(hnd).ptr), + pointer(JIT.lookup(hnd)), ) end replaceWith = LLVM.ConstantInt( From 4f160a0de183a88eaacdff1944bd6dc5a509d72a Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 10:06:25 -0600 Subject: [PATCH 484/495] Complex bessel (#2179) * Complex bessel * fix * more tests * Update EnzymeSpecialFunctionsExt.jl * Update Project.toml --- Project.toml | 2 +- ext/EnzymeSpecialFunctionsExt.jl | 3 +++ src/compiler.jl | 5 +++++ test/ext/specialfunctions.jl | 10 ++++------ 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 46b37ccb12..fa5a2c7e99 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.8" -Enzyme_jll = "0.0.167" +Enzyme_jll = "0.0.168" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/ext/EnzymeSpecialFunctionsExt.jl b/ext/EnzymeSpecialFunctionsExt.jl index 65d87dc118..09e62e98c4 100644 --- a/ext/EnzymeSpecialFunctionsExt.jl +++ b/ext/EnzymeSpecialFunctionsExt.jl @@ -5,6 +5,9 @@ using Enzyme function __init__() Enzyme.Compiler.known_ops[typeof(SpecialFunctions._logabsgamma)] = (:logabsgamma, 1, (:digamma, typeof(SpecialFunctions.digamma))) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.bessely)] = (:cmplx_jn, 2, nothing) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselj)] = (:cmplx_jn, 2, nothing) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselk)] = (:cmplx_kn, 2, nothing) end 
end diff --git a/src/compiler.jl b/src/compiler.jl index 2f4cbfc8da..5fcc53dbde 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -146,6 +146,11 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data name, arity, toinject = cmplx_known_ops[func] Tys = (Complex{Float32}, Complex{Float64}) if length(sparam_vals) == arity + if name == :cmplx_jn || name == :cmplx_yn + if (sparam_vals[2] ∈ Tys) && sparam_vals[2].parameters[1] == sparam_vals[1] + return name, toinject, sparam_vals[2] + end + end T = first(sparam_vals) if (T isa Type) T = T::Type diff --git a/test/ext/specialfunctions.jl b/test/ext/specialfunctions.jl index 1a87cf2d2b..a64c214489 100644 --- a/test/ext/specialfunctions.jl +++ b/test/ext/specialfunctions.jl @@ -16,11 +16,9 @@ using SpecialFunctions # test_scalar(SpecialFunctions.airyaiprime, x) # test_scalar(SpecialFunctions.airybi, x) # test_scalar(SpecialFunctions.airybiprime, x) - if x isa Real - test_scalar(SpecialFunctions.besselj0, x) - test_scalar(SpecialFunctions.besselj1, x) - test_scalar((y) -> SpecialFunctions.besselj(2, y), x) - end + test_scalar(SpecialFunctions.besselj0, x) + test_scalar(SpecialFunctions.besselj1, x) + test_scalar((y) -> SpecialFunctions.besselj(2, y), x) # test_scalar((y) -> SpecialFunctions.sphericalbessely(y, 0.5), 0.3) # test_scalar(SpecialFunctions.dawson, x) @@ -36,7 +34,7 @@ using SpecialFunctions # test_scalar(SpecialFunctions.erfcinv, x) end - if x isa Real && x > 0 + if !(x isa Real) || x > 0 test_scalar(SpecialFunctions.bessely0, x) test_scalar(SpecialFunctions.bessely1, x) test_scalar((y) -> SpecialFunctions.bessely(2, y), x) From 66ded5f20ecb69ad146dce90d25253a8b686fc84 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 13:04:33 -0600 Subject: [PATCH 485/495] workaround i1 issue in llvm.jl (#2181) --- src/absint.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 3b9034bef6..50282e745c 100644 --- 
a/src/absint.jl +++ b/src/absint.jl @@ -205,7 +205,11 @@ end end function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool - sz = sizeof(dl, arg_t) + sz = if arg_t == LLVM.IntType(1) + 1 + else + sizeof(dl, arg_t) + end if byref != GPUCompiler.BITS_VALUE if sz != sizeof(Int) throw(AssertionError("non bits type $byref of $typ2 has size $sz != sizeof(Int) from arg type $arg_t")) From 3b36ea25157efd37ebe63d1e025decb6defaeb43 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 13:04:57 -0600 Subject: [PATCH 486/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fa5a2c7e99..f7bcad34fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.19" +version = "0.13.20" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 2fa5bb1352e9771d5f42ae5e054dd459e8af5409 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 19:19:44 -0600 Subject: [PATCH 487/495] Nofree for math methods (#2184) * Nofree for math methods * fix --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 5fcc53dbde..0155e5da34 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3961,6 +3961,9 @@ end lowerConvention = false end k_name = LLVM.name(llvmfn) + if !has_fn_attr(llvmfn, EnumAttribute("nofree")) + push!(LLVM.function_attributes(llvmfn), EnumAttribute("nofree")) + end end name = string(name) From 865cced8bf96e6300663fe8d8775637957ec056f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 00:02:44 -0600 Subject: [PATCH 488/495] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f7bcad34fc..940446335e 100644 
--- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.20" +version = "0.13.21" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 3edec409c4e43590320df0b02a3463a24638ae0e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 00:03:35 -0600 Subject: [PATCH 489/495] Fix higher order codegen (#2161) * Fix higher order codegen * fix * fix * working * Update validation.jl * handle, again * Update validation.jl --- src/compiler.jl | 26 +++- src/compiler/interpreter.jl | 25 +--- src/compiler/validation.jl | 231 ++++++------------------------------ src/llvm/transforms.jl | 188 +++++++++++++++++++++++++++++ src/rules/parallelrules.jl | 4 +- 5 files changed, 245 insertions(+), 229 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0155e5da34..36d9c1473d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5226,12 +5226,12 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType)) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), - TapeType, + TapeType ) end @@ -5269,7 +5269,7 @@ end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool = true) +function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String} mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf @@ -5287,7 +5287,12 @@ function _thunk(job, postopt::Bool = true) end # Run post 
optimization pipeline - if postopt + prepost = if postopt + mstr = if job.config.params.ABI <: InlineABI + "" + else + string(mod) + end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) if DumpPostOpt[] @@ -5296,12 +5301,17 @@ function _thunk(job, postopt::Bool = true) else propagate_returned!(mod) end + mstr + else + "" end - return (mod, adjoint_name, primal_name, meta.TapeType) + return (mod, adjoint_name, primal_name, meta.TapeType, prepost) end const cache = Dict{UInt,CompileResult}() +const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}() + const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult key = hash(job) @@ -5313,6 +5323,12 @@ const cache_lock = ReentrantLock() if obj === nothing asm = _thunk(job) obj = _link(job, asm...) + if obj.adjoint isa Ptr{Nothing} + autodiff_cache[obj.adjoint] = (asm[2], asm[5]) + end + if obj.primal isa Ptr{Nothing} && asm[3] isa String + autodiff_cache[obj.primal] = (asm[3], asm[5]) + end cache[key] = obj end obj diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2d02604eda..2f9d1fbf60 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -44,7 +44,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter forward_rules::Bool reverse_rules::Bool - deferred_lower::Bool broadcast_rewrite::Bool handler::T end @@ -55,7 +54,6 @@ function EnzymeInterpreter( world::UInt, forward_rules::Bool, reverse_rules::Bool, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing ) @@ -83,7 +81,6 @@ function EnzymeInterpreter( IdDict{Any, Bool}(), forward_rules, reverse_rules, - deferred_lower, broadcast_rewrite, handler ) @@ -94,10 +91,9 @@ EnzymeInterpreter( mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing -) = 
EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler) Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params @@ -865,25 +861,6 @@ function abstract_call_known( end end - if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4 - if widenconst(argtypes[2]) <: Enzyme.Mode && - widenconst(argtypes[3]) <: Enzyme.Annotation && - widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : - [:(Enzyme.autodiff_deferred), fargs[2:end]...], - [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], - ) - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - Enzyme.autodiff_deferred::Any, - arginfo2::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) - end - end if interp.handler != nothing return interp.handler(interp, f, arginfo, si, sv, max_methods) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index e109415d0f..525e4d874c 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -129,9 +129,7 @@ end function memoize!(ptr::Ptr{Cvoid}, fn::String)::String fn = get(ptr_map, ptr, fn) - if !haskey(ptr_map, ptr) - ptr_map[ptr] = fn - else + if haskey(ptr_map, ptr) @assert ptr_map[ptr] == fn end return fn @@ -185,194 +183,6 @@ function check_ir(@nospecialize(job::CompilerJob), mod::LLVM.Module) end end -# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x 
i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA -function rewrite_ccalls!(mod::LLVM.Module) - for f in collect(functions(mod)) - replaceAndErase = Tuple{Instruction,Instruction}[] - for bb in blocks(f), inst in instructions(bb) - if isa(inst, LLVM.CallInst) - fn = called_operand(inst) - changed = false - B = IRBuilder() - position!(B, inst) - if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" - uservals = LLVM.Value[] - for lval in collect(arguments(inst)) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - changed = true - end - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - if !isdefined(LLVM, :OperandBundleDef) - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(operand_bundles(inst)), - prevname, - ) - else - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), - prevname, - ) - end - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - continue - end - if 
!isdefined(LLVM, :OperandBundleDef) - newbundles = OperandBundle[] - else - newbundles = OperandBundleDef[] - end - for bunduse in operand_bundles(inst) - if isdefined(LLVM, :OperandBundleDef) - bunduse = LLVM.OperandBundleDef(bunduse) - end - - if !isdefined(LLVM, :OperandBundleDef) - if LLVM.tag(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - else - if LLVM.tag_name(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - end - uservals = LLVM.Value[] - subchanged = false - for lval in LLVM.inputs(bunduse) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - subchanged = true - end - if !subchanged - push!(newbundles, bunduse) - continue - end - changed = true - if !isdefined(LLVM, :OperandBundleDef) - push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) - else - push!( - newbundles, - OperandBundleDef(LLVM.tag_name(bunduse), uservals), - ) - end - end - changed = false - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - newinst = call!( - B, - called_type(inst), - called_operand(inst), - collect(arguments(inst)), - newbundles, - prevname, - ) - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - 
API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - end - end - for (inst, newinst) in replaceAndErase - replace_uses!(inst, newinst) - LLVM.API.LLVMInstructionEraseFromParent(inst) - end - end -end - function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod::LLVM.Module) imported = Set(String[]) if haskey(functions(mod), "malloc") @@ -390,14 +200,14 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) eraseInst(mod, f) end - rewrite_ccalls!(mod) + Compiler.rewrite_ccalls!(mod) del = LLVM.Function[] for f in collect(functions(mod)) if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -408,7 +218,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -417,7 +227,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod return errors end -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}, mod::LLVM.Module) calls = LLVM.CallInst[] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 mod = LLVM.parent(f) @@ -643,7 +453,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp while length(calls) > 0 inst = pop!(calls) - check_ir!(job, errors, imported, inst, calls) + check_ir!(job, errors, imported, inst, calls, mod) end 
return errors end @@ -690,7 +500,7 @@ end import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}, mod::LLVM.Module) world = job.world interp = GPUCompiler.get_interpreter(job) method_table = Core.Compiler.method_table(interp) @@ -1211,13 +1021,36 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp ptr_val = convert(Int, ptr_arg) ptr = Ptr{Cvoid}(ptr_val) + if haskey(autodiff_cache, ptr) + pname, pmod = autodiff_cache[ptr] + + @assert !haskey(functions(mod), pname) + + pmod = parse(LLVM.Module, pmod) + + @assert haskey(functions(pmod), pname) + + for fn in functions(pmod) + if !isempty(LLVM.blocks(fn)) + linkage!(fn, LLVM.name(fn) != pname ? 
LLVM.API.LLVMInternalLinkage : LLVM.API.LLVMExternalLinkage) + end + end + + GPUCompiler.link_library!(mod, pmod) + + replaceWith = functions(mod)[pname] + push!(function_attributes(replaceWith), EnumAttribute("alwaysinline")) + linkage!(functions(mod)[pname], LLVM.API.LLVMInternalLinkage) + replace_uses!(ptr_arg, LLVM.const_pointercast(replaceWith, value_type(ptr_arg))) + return errors + end + # look it up in the Julia JIT cache frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint), ptr, 0) if length(frames) >= 1 fn, file, line, linfo, fromC, inlined = last(frames) - # Remember pointer in our global map fn = FFI.memoize!(ptr, string(fn)) if length(fn) > 1 && fromC @@ -1229,6 +1062,8 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp fn, LLVM.API.LLVMGetCalledFunctionType(inst), ) + # Remember pointer for subsequent restoration + push!(function_attributes(LLVM.Function(lfn)), StringAttribute("enzymejl_needs_restoration", string(reinterpret(UInt, ptr)))) else lfn = LLVM.API.LLVMConstBitCast( lfn, diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index aebb8bab5c..2f9c61c0b4 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -1,4 +1,192 @@ +# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA +function rewrite_ccalls!(mod::LLVM.Module) + for f in collect(functions(mod)) + replaceAndErase = Tuple{Instruction,Instruction}[] + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.CallInst) + fn = called_operand(inst) + changed = false + B = IRBuilder() + position!(B, inst) + if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" + uservals = LLVM.Value[] + for lval in collect(arguments(inst)) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals 
+ if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + changed = true + end + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + if !isdefined(LLVM, :OperandBundleDef) + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(operand_bundles(inst)), + prevname, + ) + else + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), + prevname, + ) + end + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + continue + end + if !isdefined(LLVM, :OperandBundleDef) + newbundles = OperandBundle[] + else + newbundles = OperandBundleDef[] + end + for bunduse in operand_bundles(inst) + if isdefined(LLVM, :OperandBundleDef) + bunduse = LLVM.OperandBundleDef(bunduse) + end + + if !isdefined(LLVM, :OperandBundleDef) + if LLVM.tag(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + else + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + end + uservals = LLVM.Value[] + subchanged = false + for lval in LLVM.inputs(bunduse) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = 
get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + subchanged = true + end + if !subchanged + push!(newbundles, bunduse) + continue + end + changed = true + if !isdefined(LLVM, :OperandBundleDef) + push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) + else + push!( + newbundles, + OperandBundleDef(LLVM.tag_name(bunduse), uservals), + ) + end + end + changed = false + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!( + B, + called_type(inst), + called_operand(inst), + collect(arguments(inst)), + newbundles, + prevname, + ) + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + end + end + for (inst, newinst) in replaceAndErase + replace_uses!(inst, newinst) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end +end + function force_recompute!(mod::LLVM.Module) for f in functions(mod), bb in blocks(f) iter = LLVM.API.LLVMGetFirstInstruction(bb) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 78c9cd9ce8..d4356aba61 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -275,7 +275,7 @@ end world, ) - cmod, fwdmodenm, _, _ = _thunk(ejob, false) 
#=postopt=# + cmod, fwdmodenm, _, _, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) @@ -334,7 +334,7 @@ end world, ) - cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, false) #=postopt=# + cmod, adjointnm, augfwdnm, TapeType, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) From 551ddd1ca94f1a2f7b1fe018e6415af191115ec9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 00:59:01 -0600 Subject: [PATCH 490/495] Update errors.jl --- src/errors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index 4f7b6f9480..83946f1dbd 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -29,7 +29,7 @@ function Base.showerror(io::IO, ece::NoDerivativeException) end end if occursin("cannot handle unknown binary operator", ece.msg) - for msg in ece.msg.split('\n') + for msg in split(ece.msg, '\n') if occursin("cannot handle unknown binary operator", msg) print('\n', msg, '\n') end @@ -111,7 +111,7 @@ function Base.showerror(io::IO, ece::EnzymeInternalError) end print(io, '\n', ece.msg, '\n') else - for msg in ece.msg.split('\n') + for msg in split(ece.msg, '\n') if occursin("Illegal replace ficticious phi for", msg) print('\n', msg, '\n') end From 8e10a0a37d42db7ced533461bdc3d986ce22e3af Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 11:05:03 -0600 Subject: [PATCH 491/495] World backedge holder (#2183) * World backedge holder * fix * fixup * Update compiler.jl * Update compiler.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * try2 * nothing works rip * more test * hn * keep trying * more * fix * fix * isapplic * fix2 * mark broken --- src/compiler.jl | 17 +++- src/compiler/interpreter.jl | 176 ++++++++++++++++++++++++++++++++---- test/ruleinvalidation.jl | 7 +- 3 files changed, 181 insertions(+), 19 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 
36d9c1473d..ddccde1a24 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5573,7 +5573,22 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: # new_ci.min_world = min_world[] new_ci.min_world = world new_ci.max_world = max_world[] - new_ci.edges = Core.MethodInstance[mi] + + edges = Core.MethodInstance[mi] + + if Mode == API.DEM_ForwardMode + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, world)) + Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) + else + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, world)) + end + + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, world)) + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{Val{0}}, world)) + Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(Val(0))) + + new_ci.edges = edges + # XXX: setting this edge does not give us proper method invalidation, see # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. 
# invoking `code_llvm` also does the necessary codegen, as does calling the diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2f9d1fbf60..bd57ec92dd 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -8,6 +8,7 @@ using Core.Compiler: OptimizationParams, MethodInstance using GPUCompiler: @safe_debug +using GPUCompiler if VERSION < v"1.11.0-DEV.1552" using GPUCompiler: CodeCache, WorldView, @safe_debug end @@ -23,6 +24,141 @@ else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end +function rule_backedge_holder_generator(world::UInt, source, self, ft::Type) + @nospecialize + sig = Tuple{typeof(Base.identity), Int} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + has_ambig = Ptr{Int32}(C_NULL) + mthds = Base._methods_by_ftype( + sig, + nothing, + -1, #=lim=# + world, + false, #=ambig=# + min_world, + max_world, + has_ambig, + ) + mtypes, msp, m = mthds[1] + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + m, + mtypes, + msp, + ) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + + ### TODO: backedge from inactive, augmented_primal, forward, reverse + edges = Any[] + + @static if false + if ft == typeof(EnzymeRules.augmented_primal) + # this is illegal + # sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.augmented_primal), 
Tuple{<:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world)) + elseif ft == typeof(EnzymeRules.forward) + # this is illegal + # sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.forward), Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world)) + else + # sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive), Tuple{Vararg{Annotation}}, world)) + + # sig = Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_noinl), Tuple{Vararg{Annotation}}, world)) + + # sig = Tuple{typeof(EnzymeRules.noalias), Vararg{Any}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.noalias), Tuple{Vararg{Any}}, world)) + + # sig = Tuple{typeof(EnzymeRules.inactive_type), Type} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_type), Tuple{Type}, world)) + end + end + + new_ci.edges = edges + + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. 
+ + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :ft] + new_ci.slotflags = UInt8[0x00 for i = 1:2] + + # return the codegen world age + push!(new_ci.code, Core.Compiler.ReturnNode(0)) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below + end + new_ci.ssavaluetypes += 1 + + return new_ci +end + +@eval Base.@assume_effects :removable :foldable :nothrow @inline function rule_backedge_holder(ft) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, rule_backedge_holder_generator)) +end + +begin + # Forward-rule catch all + fwd_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}) + # Reverse-rule catch all + rev_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}) + # Inactive-rule catch all + ina_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}) + # All other derivative-related catch all (just for autodiff, not inference), including inactive_noinl, noalias, and inactive_type + gen_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{Val{0}}) + + + fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + EnzymeRules.add_mt_backedge!(fwd_rule_be, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable, fwd_sig) + + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + EnzymeRules.add_mt_backedge!(rev_rule_be, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable, rev_sig) + + + for ina_sig in ( + 
Tuple{typeof(EnzymeRules.inactive), Vararg{Any}}, + ) + EnzymeRules.add_mt_backedge!(ina_rule_be, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable, ina_sig) + end + + for gen_sig in ( + Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, + Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_type), Type}, + ) + EnzymeRules.add_mt_backedge!(gen_rule_be, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable, gen_sig) + end +end + struct EnzymeInterpreter{T} <: AbstractInterpreter @static if HAS_INTEGRATED_CACHE token::Any @@ -40,8 +176,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams - rules_cache::IdDict{Any, Bool} - forward_rules::Bool reverse_rules::Bool broadcast_rewrite::Bool @@ -78,7 +212,6 @@ function EnzymeInterpreter( # parameters for inference and optimization parms, OptimizationParams(), - IdDict{Any, Bool}(), forward_rules, reverse_rules, broadcast_rewrite, @@ -99,6 +232,7 @@ Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache + @static if HAS_INTEGRATED_CACHE Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token else @@ -221,25 +355,35 @@ function Core.Compiler.abstract_call_gf_by_type( elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) else - # 1. Check if function is inactive - if is_inactive_from_sig(interp, specTypes, sv) + method_table = Core.Compiler.method_table(interp) + if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else - # 2. 
Check if rule is defined - has_rule = get!(interp.rules_cache, specTypes) do - if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) - return true - elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) - return true - else - return false + if interp.forward_rules + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :frule) + end + end + + if interp.reverse_rules + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end - if has_rule - callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) - end end + + if interp.forward_rules + Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, interp.world)::Core.MethodInstance) + Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) + end + if interp.reverse_rules + Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance) + Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.augmented_primal)) + end + Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, interp.world)::Core.MethodInstance) + Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(typeof(EnzymeRules.inactive))) end + @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) else diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 37cb21b08f..501b0aac10 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -42,6 +42,9 @@ end # now test 
invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 - +@static if VERSION < v"1.11-" + @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +else + @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +end end # module From 5ae464337177200053f0d19bdfcf042ba6a796ba Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 11:36:27 -0600 Subject: [PATCH 492/495] Update EnzymeSpecialFunctionsExt.jl --- ext/EnzymeSpecialFunctionsExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/EnzymeSpecialFunctionsExt.jl b/ext/EnzymeSpecialFunctionsExt.jl index 09e62e98c4..5c3b59a265 100644 --- a/ext/EnzymeSpecialFunctionsExt.jl +++ b/ext/EnzymeSpecialFunctionsExt.jl @@ -5,7 +5,8 @@ using Enzyme function __init__() Enzyme.Compiler.known_ops[typeof(SpecialFunctions._logabsgamma)] = (:logabsgamma, 1, (:digamma, typeof(SpecialFunctions.digamma))) - Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.bessely)] = (:cmplx_jn, 2, nothing) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.bessely)] = (:cmplx_yn, 2, nothing) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besseli)] = (:cmplx_jn, 2, nothing) Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselj)] = (:cmplx_jn, 2, nothing) Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselk)] = (:cmplx_kn, 2, nothing) end From 6ab27de1434683acc39ecbbc236ff90040322330 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 8 Dec 2024 13:13:16 -0600 Subject: [PATCH 493/495] absint fixup (#2185) * absint fixup * Update Project.toml --------- Co-authored-by: William Moses --- src/absint.jl | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 50282e745c..e34500a335 100644 --- a/src/absint.jl +++ 
b/src/absint.jl @@ -566,10 +566,18 @@ function abs_typeof( elseif fo > offset offset = offset - typed_fieldoffset(typ, lasti) typ = typed_fieldtype(typ, lasti) - @assert Base.isconcretetype(typ) - if !Base.allocatedinline(typ) - legal = false - end + if offset == 0 + if !Base.allocatedinline(typ) + if byref != GPUCompiler.BITS_VALUE + legal = false + end + byref = GPUCompiler.MUT_REF + end + else + if !Base.isconcretetype(typ) || !Base.allocatedinline(typ) + legal = false + end + end seen = true break end @@ -603,7 +611,11 @@ function abs_typeof( typ2 = typ while legal && should_recurse(typ2, value_type(arg), byref, dl) - idx, _ = first_non_ghost(typ2) + if !Base.isconcretetype(typ2) + legal = false + break + end + idx, _ = first_non_ghost(typ2) if idx != -1 typ2 = typed_fieldtype(typ2, idx) if Base.allocatedinline(typ2) From 1d3b801a550e1688cd7c4fb1639c4d1d029b48a0 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 8 Dec 2024 19:13:59 +0000 Subject: [PATCH 494/495] Tests II: more Julia function tests (#969) * Julia function tests * Nested reverse test * Disable failing tests on earlier Julia versions * Enable try/catch test * Printing to find CI error * Mark try test as broken on Julia 1.6 * Mark try test as broken on Julia 1.7 * Remove skipmissing test * Remove print debugging * Revert changes * Remove version check * updates for Enzyme changes * remove higher order test --- test/runtests.jl | 105 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index fd31a13ea0..c9cde02c28 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -707,6 +707,111 @@ end Enzyme.API.strictAliasing!(true) f10(x) = hypot(x, 2x) @test autodiff(Reverse, f10, Active, Active(2.0))[1][1] == sqrt(5) + @test autodiff(Forward, f10, Duplicated(2.0, 1.0))[1] == sqrt(5) + + f11(x) = x * sum(LinRange(x, 10.0, 6)) + @test autodiff(Reverse, f11, Active, Active(2.0))[1][1] == 42 + @test autodiff(Forward, f11, 
Duplicated(2.0, 1.0))[1] == 42 + + f12(x, k) = get(Dict(1 => 1.0, 2 => x, 3 => 3.0), k, 1.0) + @test autodiff(Reverse, f12, Active, Active(2.0), Const(2))[1] == (1.0, nothing) + @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(2)) == (1.0,) + @test autodiff(Reverse, f12, Active, Active(2.0), Const(3))[1] == (0.0, nothing) + @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(3)) == (0.0,) + @test autodiff(Reverse, f12, Active, Active(2.0), Const(4))[1] == (0.0, nothing) + @test autodiff(Forward, f12, Duplicated(2.0, 1.0), Const(4)) == (0.0,) + + f13(x) = muladd(x, 3, x) + @test autodiff(Reverse, f13, Active, Active(2.0))[1][1] == 4 + @test autodiff(Forward, f13, Duplicated(2.0, 1.0))[1] == 4 + + f14(x) = x * cmp(x, 3) + @test autodiff(Reverse, f14, Active, Active(2.0))[1][1] == -1 + @test autodiff(Forward, f14, Duplicated(2.0, 1.0))[1] == -1 + + f15(x) = x * argmax([1.0, 3.0, 2.0]) + @test autodiff(Reverse, f15, Active, Active(3.0))[1][1] == 2 + @test autodiff(Forward, f15, Duplicated(3.0, 1.0))[1] == 2 + + f16(x) = evalpoly(2, (1, 2, x)) + @test autodiff(Reverse, f16, Active, Active(3.0))[1][1] == 4 + @test autodiff(Forward, f16, Duplicated(3.0, 1.0))[1] == 4 + + f17(x) = @evalpoly(2, 1, 2, x) + @test autodiff(Reverse, f17, Active, Active(3.0))[1][1] == 4 + @test autodiff(Forward, f17, Duplicated(3.0, 1.0))[1] == 4 + + f18(x) = widemul(x, 5.0f0) + @test autodiff(Reverse, f18, Active, Active(2.0f0))[1][1] == 5 + @test autodiff(Forward, f18, Duplicated(2.0f0, 1.0f0))[1] == 5 + + f19(x) = copysign(x, -x) + @test autodiff(Reverse, f19, Active, Active(2.0))[1][1] == -1 + @test autodiff(Forward, f19, Duplicated(2.0, 1.0))[1] == -1 + + f20(x) = sum([ifelse(i > 5, i, zero(i)) for i in [x, 2x, 3x, 4x]]) + @test autodiff(Reverse, f20, Active, Active(2.0))[1][1] == 7 + @test autodiff(Forward, f20, Duplicated(2.0, 1.0))[1] == 7 + + function f21(x) + nt = (a=x, b=2x, c=3x) + return nt.c + end + @test autodiff(Reverse, f21, Active, Active(2.0))[1][1] == 3 + @test 
autodiff(Forward, f21, Duplicated(2.0, 1.0))[1] == 3 + + f22(x) = sum(fill(x, (3, 3))) + @test autodiff(Reverse, f22, Active, Active(2.0))[1][1] == 9 + @test autodiff(Forward, f22, Duplicated(2.0, 1.0))[1] == 9 + + function f23(x) + a = similar(rand(3, 3)) + fill!(a, x) + return sum(a) + end + @test autodiff(Reverse, f23, Active, Active(2.0))[1][1] == 9 + @test autodiff(Forward, f23, Duplicated(2.0, 1.0))[1] == 9 + + function f24(x) + try + return 3x + catch + return 2x + end + end + @test autodiff(Reverse, f24, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f24, Duplicated(2.0, 1.0))[1] == 3 + + function f25(x) + try + sqrt(-1.0) + return 3x + catch + return 2x + end + end + @test autodiff(Reverse, f25, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f25, Duplicated(2.0, 1.0))[1] == 2 + + f26(x) = circshift([1.0, 2x, 3.0], 1)[end] + @test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2 + @test autodiff(Forward, f26, Duplicated(2.0, 1.0))[1] == 2 + + f27(x) = repeat([x 3x], 3)[2, 2] + @test autodiff(Reverse, f27, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f27, Duplicated(2.0, 1.0))[1] == 3 + + f28(x) = x * sum(trues(4, 3)) + @test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 12 + @test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 12 + + f29(x) = sum(Set([1.0, x, 2x, x])) + @test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 3 + + f30(x) = reverse([x 2.0 3x])[1] + @test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 3 + @test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 3 end function deadarg_pow(z::T, i) where {T<:Real} From 7c0823fa64426a745dae8fc7a50980be0a0a8dc8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 9 Dec 2024 18:26:22 -0600 Subject: [PATCH 495/495] Fix method table override (#2191) * Fix method table override * fix * fix --------- Co-authored-by: William Moses --- Project.toml | 2 +- 
src/compiler/interpreter.jl | 7 +++---- src/compiler/validation.jl | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 940446335e..691488b87e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.21" +version = "0.13.22" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index bd57ec92dd..60751a1004 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -165,7 +165,7 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter else code_cache::CodeCache end - method_table::Union{Nothing,Core.MethodTable} + method_table::Core.Compiler.MethodTableView # Cache of inference results for this particular interpreter local_cache::Vector{InferenceResult} @@ -201,7 +201,7 @@ function EnzymeInterpreter( return EnzymeInterpreter( cache_or_token, - mt, + mt == nothing ? 
Core.Compiler.InternalMethodTable(world) : Core.Compiler.OverlayMethodTable(world, mt), # Initially empty cache Vector{InferenceResult}(), @@ -253,8 +253,7 @@ Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false -Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) = - Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) +Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter)) = interp.method_table function is_alwaysinline_func(@nospecialize(TT))::Bool isa(TT, DataType) || return false diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 525e4d874c..20be3891a0 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -483,7 +483,7 @@ end end @inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.OverlayMethodTable) - return has_method(sig, mt.mt, mt.world) || has_method(sig, nothing, mt.world) + return has_method(sig, mt.world, mt.mt) || has_method(sig, mt.world, nothing) end @inline function is_inactive(@nospecialize(tys::Union{Vector{Union{Type,Core.TypeofVararg}}, Core.SimpleVector}), world::UInt, @nospecialize(mt))