diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2ef66a1571..5b58746b2a 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -212,4 +212,53 @@ let # overload `inlining_policy` end end -end # module Interpreter +import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, + CallMeta, Effects, NoCallInfo, widenconst, mapany + +struct AutodiffCallInfo <: CallInfo + # ... + info::CallInfo +end + + +unwrap_annotation(A::Type{<:Enzyme.Annotation}) = eltype(A) +unwrap_annotation(A::Core.Const) = Core.Const((A.val::Enzyme.Annotation).val) + +function abstract_autodiff(interp::AbstractInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) + (; fargs, argtypes) = arginfo + # Requires Mode to be Const + if length(argtypes) < 4 || !(widenconst(argtypes[3]) <: Enzyme.Annotation) # [autodiff, mode, FA, A, ...] + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, f, arginfo::ArgInfo, + si::StmtInfo, sv::AbsIntState, max_methods::Int) + end + + primal_argvec = mapany(unwrap_annotation, Any[argtypes[3], argtypes[5:end]...]) + primal_call = abstract_call(interp, ArgInfo(nothing, primal_argvec), si, sv, max_methods) + primal_info = primal_call.info + primal_rt = primal_call.rt + # TODO: Calculate proper return type of autodiff + autodiff_rt = Any + # autodiff_rt = primal_rt + @show primal_call + @static if VERSION < v"1.11.0-" + return CallMeta(autodiff_rt, Effects(), AutodiffCallInfo(primal_info)) + else + return CallMeta(Nothing, autodiff_rt, Effects(), AutodiffCallInfo(primal_info)) + end +end + +function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, + max_methods::Int = get_max_methods(interp, f, sv)) + + if f === Enzyme.autodiff || f === Enzyme.autodiff_deferred + return abstract_autodiff(interp, f, arginfo, si, sv, max_methods) + end + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, f, arginfo::ArgInfo, + si::StmtInfo, sv::AbsIntState, max_methods::Int) +end + +end