NaN in linalg eig (pytorch#67544)
Summary:
Fixes pytorch#61251. Per the discussion there (pytorch#61251 (comment)), consensus was reached to raise an error when the input to `eig()` contains a NaN value. This PR implements that check.
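The behavior can be sketched with a NumPy stand-in (the `checked_eig` helper below is hypothetical, not part of this PR): reject any non-finite entry up front, mirroring the `TORCH_CHECK` added in `linalg_eig_out`.

```python
import numpy as np

def checked_eig(a):
    # Sketch of the precondition this PR adds: LAPACK's geev can hang or
    # return garbage on inf/NaN input, so fail fast with a clear error.
    if not np.isfinite(a).all():
        raise RuntimeError(
            "torch.linalg.eig: input tensor should not contain infs or NaNs.")
    return np.linalg.eig(a)
```

With a finite matrix the call proceeds normally; with a NaN anywhere in the input it raises the same message the new `TORCH_CHECK` produces.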

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Pull Request resolved: pytorch#67544

Reviewed By: malfet

Differential Revision: D32310919

Pulled By: mruberry

fbshipit-source-id: fc74a1ae2d929157c2d4c9051e3e9a4bf03dd5be
v0dro authored and facebook-github-bot committed Nov 11, 2021
1 parent d049772 commit 6afb414
Showing 3 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp

@@ -2712,6 +2712,7 @@ std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& va
}

std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values, Tensor& vectors) {
  TORCH_CHECK(input.isfinite().all().item<bool>(), "torch.linalg.eig: input tensor should not contain infs or NaNs.");
  squareCheckInputs(input, "linalg.eig");

  // unlike NumPy for real-valued inputs the output is always complex-valued
1 change: 1 addition & 0 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

@@ -206,6 +206,7 @@ std::tuple<Tensor, Tensor> eig_kernel_impl(const Tensor& self, bool& eigenvector
  });
  // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
  singleCheckErrors(info, "eig_cpu");

  return std::tuple<Tensor, Tensor>(vals_, vecs_);
}
12 changes: 12 additions & 0 deletions test/test_linalg.py

@@ -2295,6 +2295,18 @@ def test_eig_errors_and_warnings(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
            torch.linalg.eig(a, out=(out_w, out_v))

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    def test_eig_with_nan(self, device, dtype):
        for val in [np.inf, np.nan]:
            for batch_dim in [(), (10,)]:
                a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype)
                a[..., -1, -1] = val

                with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"):
                    torch.linalg.eig(a)

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    # NumPy computes only in float64 and complex128 precisions