Skip to content

Commit

Permalink
Fix hcat_fill and custom rule type propagation (#1128)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored and michel2323 committed Nov 7, 2023
1 parent c8dec44 commit 999926d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4740,7 +4740,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,

idx = 0
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig)))))
for (v, Ty) in zip(actives, Tys)
Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active)
for (v, Ty) in zip(actives, Tys2)
TT = typetree(Ty, ctx, dl)
Typ = C_NULL
ext = extract_value!(B, res, idx)
Expand Down Expand Up @@ -9294,7 +9295,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};

if params.run_enzyme
# Generate the adjoint
jlrules = String[]
jlrules = String["enzyme_custom"]
for (fname, (ftyp, mi)) in foundTys
haskey(functions(mod), fname) || continue
push!(jlrules, fname)
Expand Down
51 changes: 51 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,54 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,

return (nothing,nothing)
end

@static if VERSION >= v"1.7-"
# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float)
function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
primal = if EnzymeRules.needs_primal(config)
out.val
else
nothing
end
shadow = if EnzymeRules.needs_shadow(config)
out.dval
else
nothing
end
func.val(out.val, inp.val)
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
nr, nc = size(out.val,1), size(out.val,2)
for b in 1:EnzymeRules.width(config)
da = if EnzymeRules.width(config) == 1
out.dval
else
out.dval[b]
end
i = 1
j = 1
if (typeof(inp) <: Active)
dinp = ntuple(Val(length(inp.val))) do k
Base.@_inline_meta
res = da[i, j]
da[i, j] = 0
j += 1
if j == nc+1
i += 1
j = 1
end
T = BT.parameters[k]
if T <: AbstractFloat
T(res)
else
T(0)
end
end
return (nothing, dinp)::Tuple{Nothing, BT}
end
end
return (nothing, nothing)
end
end
17 changes: 17 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2552,3 +2552,20 @@ end
y = A \ b
@test dA (-z * transpose(y))
end

@static if VERSION >= v"1.7-"
@testset "hvcat_fill" begin
ar = Matrix{Float64}(undef, 2, 3)
dar = [1.0 2.0 3.0; 4.0 5.0 6.0]

res = first(Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar), Active((1, 2.2, 3, 4.4, 5, 6.6))))

@test res[2][1] == 0
@test res[2][2] 2.0
@test res[2][3] 0
@test res[2][4] 4.0
@test res[2][5] 0
@test res[2][6] 6.0
end
end

0 comments on commit 999926d

Please sign in to comment.