From 999926d8441db6cca57fae7699c0a4a01652e747 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Sun, 5 Nov 2023 14:17:02 -0600
Subject: [PATCH] Fix hcat_fill and custom rule type propagation (#1128)

---
 src/compiler.jl       |  5 +++--
 src/internal_rules.jl | 51 +++++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl      | 17 +++++++++++++++
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/src/compiler.jl b/src/compiler.jl
index a85bc90b522..da31ebf9afa 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -4740,7 +4740,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
 
         idx = 0
         dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig)))))
-        for (v, Ty) in zip(actives, Tys)
+        Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active)
+        for (v, Ty) in zip(actives, Tys2)
             TT = typetree(Ty, ctx, dl)
             Typ = C_NULL
             ext = extract_value!(B, res, idx)
@@ -9294,7 +9295,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
 
     if params.run_enzyme
         # Generate the adjoint
-        jlrules = String[]
+        jlrules = String["enzyme_custom"]
         for (fname, (ftyp, mi)) in foundTys
             haskey(functions(mod), fname) || continue
             push!(jlrules, fname)
diff --git a/src/internal_rules.jl b/src/internal_rules.jl
index 1be9c157b19..5f8b409be62 100644
--- a/src/internal_rules.jl
+++ b/src/internal_rules.jl
@@ -428,3 +428,54 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,
 
     return (nothing,nothing)
 end
+
+@static if VERSION >= v"1.7-"
+# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float)
+function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
+    primal = if EnzymeRules.needs_primal(config)
+        out.val
+    else
+        nothing
+    end
+    shadow = if EnzymeRules.needs_shadow(config)
+        out.dval
+    else
+        nothing
+    end
+    func.val(out.val, inp.val)
+    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
+end
+
+function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
+    nr, nc = size(out.val,1), size(out.val,2)
+    for b in 1:EnzymeRules.width(config)
+        da = if EnzymeRules.width(config) == 1
+            out.dval
+        else
+            out.dval[b]
+        end
+        i = 1
+        j = 1
+        if (typeof(inp) <: Active)
+            dinp = ntuple(Val(length(inp.val))) do k
+                Base.@_inline_meta
+                res = da[i, j]
+                da[i, j] = 0
+                j += 1
+                if j == nc+1
+                    i += 1
+                    j = 1
+                end
+                T = BT.parameters[k]
+                if T <: AbstractFloat
+                    T(res)
+                else
+                    T(0)
+                end
+            end
+            return (nothing, dinp)::Tuple{Nothing, BT}
+        end
+    end
+    return (nothing, nothing)
+end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 44d003f86fa..b98ed06131c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2552,3 +2552,20 @@ end
     y = A \ b
     @test dA ≈ (-z * transpose(y))
 end
+
+@static if VERSION >= v"1.7-"
+@testset "hvcat_fill" begin
+    ar = Matrix{Float64}(undef, 2, 3)
+    dar = [1.0 2.0 3.0; 4.0 5.0 6.0]
+
+    res = first(Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar), Active((1, 2.2, 3, 4.4, 5, 6.6))))
+
+    @test res[2][1] == 0
+    @test res[2][2] ≈ 2.0
+    @test res[2][3] ≈ 0
+    @test res[2][4] ≈ 4.0
+    @test res[2][5] ≈ 0
+    @test res[2][6] ≈ 6.0
+end
+end
+