Skip to content

Commit

Permalink
Improve rule arg mixed errors (#1530)
Browse files Browse the repository at this point in the history
* Improve rule arg mixed errors

* fixup

* improve errs
  • Loading branch information
wsmoses authored Jun 10, 2024
1 parent df7dd87 commit 6c2b0d9
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
16 changes: 15 additions & 1 deletion src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,14 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,

push!(activity, Ty)

elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(arg.typ, world) )
elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world, #=justActive=#Val(true)) == ActiveState)
Ty = Active{arg.typ}
llty = convert(LLVMType, Ty)
arty = convert(LLVMType, arg.typ; allow_boxed=true)
if B !== nothing
if active_reg_inner(arg.typ, (), world, #=justActive=#Val(false)) == MixedState
emit_error(B, orig, "Enzyme: Argument type $(arg.typ) has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")
end
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))
Expand Down Expand Up @@ -716,6 +719,17 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4]))
innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)])
if innerTy != value_type(tape)
if isabstracttype(TapeT)
msg = sprint() do io
println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))")
println(io, "tape_idx=", tape_idx)
println(io, "sret=", sret)
println(io, "RT=", RT)
println(io, "tape=", tape)
println(io, "llvmf=", string(llvmf))
end
throw(AssertionError(msg))
end
llty = convert(LLVMType, TapeT; allow_boxed=true)
al0 = al = emit_allocobj!(B, TapeT)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
Expand Down
12 changes: 10 additions & 2 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,7 @@ for (N, Width) in Iterators.product(0:30, 1:10)
eval(func_runtime_iterate_rev(N, Width))
end

function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true)
function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true)
width = get_width(gutils)
mode = get_mode(gutils)
mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
Expand All @@ -1132,7 +1132,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder,
T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)

if firstconst
if firstconst && !firstconst_after_tape
val = new_from_original(gutils, operands(orig)[start])
if lookup
val = lookup_value(gutils, val, B)
Expand Down Expand Up @@ -1196,6 +1196,14 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder,
else
pushfirst!(vals, unsafe_to_llvm(Val(ReturnType)))
end

if firstconst && firstconst_after_tape
val = new_from_original(gutils, operands(orig)[start])
if lookup
val = lookup_value(gutils, val, B)
end
pushfirst!(vals, val)
end

if mode != API.DEM_ForwardMode
uncacheable = get_uncacheable(gutils, orig)
Expand Down
12 changes: 6 additions & 6 deletions src/rules/typeunstablerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,19 @@ function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primarg
end

function func_runtime_newstruct_augfwd(N, Width)
primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width)
primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true)
body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs)

quote
function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)}
function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)}
$body
end
end
end

@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType}
@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType}
N = div(length(allargs)+2, Width+1)-1
primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs)
primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true)
return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs)
end

Expand Down Expand Up @@ -325,7 +325,7 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap

width = get_width(gutils)

sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false)
sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true)

if width == 1
shadow = sret
Expand Down Expand Up @@ -369,7 +369,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape)
if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing)
@assert tape !== C_NULL
width = get_width(gutils)
generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape)
generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true)
end

return nothing
Expand Down

0 comments on commit 6c2b0d9

Please sign in to comment.