From 6afb414c21543fb62e2f862268a707657c83cd72 Mon Sep 17 00:00:00 2001 From: Sameer Deshmukh Date: Thu, 11 Nov 2021 14:32:13 -0800 Subject: [PATCH] Nan in linalg eig (#67544) Summary: Fixes https://github.com/pytorch/pytorch/issues/61251. As per the comment here (https://github.com/pytorch/pytorch/issues/61251#issuecomment-954676082), a consensus has been reached to raise an error if there is a NaN value in the input when calling `eig()`. This PR implements that feature. cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano Pull Request resolved: https://github.com/pytorch/pytorch/pull/67544 Reviewed By: malfet Differential Revision: D32310919 Pulled By: mruberry fbshipit-source-id: fc74a1ae2d929157c2d4c9051e3e9a4bf03dd5be --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 1 + aten/src/ATen/native/BatchLinearAlgebraKernel.cpp | 1 + test/test_linalg.py | 12 ++++++++++++ 3 files changed, 14 insertions(+) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 34d3f2dd92090..853f5c2cbeb5e 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2712,6 +2712,7 @@ std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& va } std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values, Tensor& vectors) { + TORCH_CHECK(input.isfinite().all().item<bool>(), "torch.linalg.eig: input tensor should not contain infs or NaNs."); squareCheckInputs(input, "linalg.eig"); // unlike NumPy for real-valued inputs the output is always complex-valued diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 0ff2b4155e5e0..7194e80c770ab 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -206,6 +206,7 @@ std::tuple<Tensor, Tensor> eig_kernel_impl(const Tensor& self, bool& eigenvector }); // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) 
singleCheckErrors(info, "eig_cpu"); + return std::tuple<Tensor, Tensor>(vals_, vecs_); } diff --git a/test/test_linalg.py b/test/test_linalg.py index dc6d5c75e3b6f..7bcae35a2260f 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -2295,6 +2295,18 @@ def test_eig_errors_and_warnings(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): torch.linalg.eig(a, out=(out_w, out_v)) + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(*floating_and_complex_types()) + def test_eig_with_nan(self, device, dtype): + for val in [np.inf, np.nan]: + for batch_dim in [(), (10,)]: + a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype) + a[..., -1, -1] = val + + with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"): + torch.linalg.eig(a) + @skipCPUIfNoLapack + @skipCUDAIfNoMagma # NumPy computes only in float64 and complex128 precisions