diff --git a/src/Enzyme.jl b/src/Enzyme.jl index e2244171a5..9258079ba6 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -687,7 +687,7 @@ code, as well as high-order differentiation. A2 = A if A isa UnionAll - rt = Compiler.primal_return_type(mode, FTy, tt) + rt = Compiler.primal_return_type(Reverse, FTy, tt) A2 = A{rt} if rt == Union{} rt = Nothing @@ -840,7 +840,7 @@ code, as well as high-order differentiation. FT = Core.Typeof(f.val) if RT isa UnionAll - rt = Compiler.primal_return_type(mode, FT, tt) + rt = Compiler.primal_return_type(Forward, FT, tt) if rt == Union{} rt = Nothing end @@ -1382,7 +1382,7 @@ result, ∂v, ∂A TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} - rt0 = Compiler.primal_return_type(mode, eltype(FA), primal_tt) + rt0 = Compiler.primal_return_type(Reverse, eltype(FA), primal_tt) rt = Compiler.remove_innerty(A2){rt0}