diff --git a/src/Enzyme.jl b/src/Enzyme.jl index aa018ea23b..17a7c6ff5d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -682,15 +682,19 @@ code, as well as high-order differentiation. if A isa UnionAll rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) - A2 = A{rt} + rt = Core.Compiler.return_type(f.val, tt) + A2 = A{rt} + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end else @assert A isa DataType rt = A + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end end - if rt == Union{} - error("Return type inferred to be Union{}. Giving up.") - end ModifiedBetweenT = falses_from_args(Nargs + 1) ModifiedBetween = Val(ModifiedBetweenT)