diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 9258079ba6..5e16971894 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1381,10 +1381,13 @@ result, ∂v, ∂A TT = Tuple{args...} - primal_tt = Tuple{map(eltype, args)...} - rt0 = Compiler.primal_return_type(Reverse, eltype(FA), primal_tt) - - rt = Compiler.remove_innerty(A2){rt0} + rt = if A2 isa UnionAll + primal_tt = Tuple{map(eltype, args)...} + rt0 = Compiler.primal_return_type(Reverse, eltype(FA), primal_tt) + A2{rt0} + else + A2 + end primal_ptr = Compiler.deferred_codegen( FA,