From d100d98db8424d6d88c9d72d156c49ba94e9d18e Mon Sep 17 00:00:00 2001
From: Sameer Deshmukh
Date: Thu, 23 Dec 2021 10:51:36 -0800
Subject: [PATCH] `torch.linalg` routines raise `torch.linalg.LinAlgError`
 when a numerical error in the computation is found. (#68571)

Summary:
This PR fixes https://github.com/pytorch/pytorch/issues/64785 by introducing
`torch.linalg.LinAlgError` for reporting errors caused by bad values in linear
algebra routines, which allows users to easily catch failures that stem from
numerical problems in the input data.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68571

Reviewed By: malfet

Differential Revision: D33254087

Pulled By: albanD

fbshipit-source-id: 94b59000fdb6a9765e397158e526d1f815f18f0f
---
 aten/src/ATen/cuda/Exceptions.h           | 38 ++++++++++++++---------
 aten/src/ATen/native/LinearAlgebra.cpp    |  2 +-
 aten/src/ATen/native/LinearAlgebraUtils.h | 13 ++++----
 c10/util/Exception.h                      | 10 ++++++
 test/test_linalg.py                       | 29 ++++++++---------
 torch/csrc/Exceptions.cpp                 | 31 +++++++++++++++++-
 torch/csrc/Exceptions.h                   | 16 +++++++++-
 torch/linalg/__init__.py                  |  2 ++
 8 files changed, 103 insertions(+), 38 deletions(-)

diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h
index 2d1fd05fa2e..94afbf09201 100644
--- a/aten/src/ATen/cuda/Exceptions.h
+++ b/aten/src/ATen/cuda/Exceptions.h
@@ -12,6 +12,7 @@
 #include
 #include
+
 namespace c10 {

 class CuDNNError : public c10::Error {
@@ -70,21 +71,28 @@ namespace at { namespace cuda { namespace solver {
 C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
 }}} // namespace at::cuda::solver

-#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
-  do {                                                                  \
-    cusolverStatus_t __err = EXPR;                                      \
-    if (__err == CUSOLVER_STATUS_EXECUTION_FAILED) {                    \
-      TORCH_CHECK(__err == CUSOLVER_STATUS_SUCCESS,                     \
-                  "cusolver error: ",                                   \
-                  at::cuda::solver::cusolverGetErrorMessage(__err),     \
-                  ", when calling `" #EXPR "`",                         \
-                  ". This error may appear if the input matrix contains NaN."); \
-    } else {                                                            \
-      TORCH_CHECK(__err == CUSOLVER_STATUS_SUCCESS,                     \
-                  "cusolver error: ",                                   \
-                  at::cuda::solver::cusolverGetErrorMessage(__err),     \
-                  ", when calling `" #EXPR "`");                        \
-    }                                                                   \
+// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
+// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
+#define TORCH_CUSOLVER_CHECK(EXPR)                                  \
+  do {                                                              \
+    cusolverStatus_t __err = EXPR;                                  \
+    if ((CUDA_VERSION < 11500 &&                                    \
+         __err == CUSOLVER_STATUS_EXECUTION_FAILED) ||              \
+        (CUDA_VERSION >= 11500 &&                                   \
+         __err == CUSOLVER_STATUS_INVALID_VALUE)) {                 \
+      TORCH_CHECK_LINALG(                                           \
+          false,                                                    \
+          "cusolver error: ",                                       \
+          at::cuda::solver::cusolverGetErrorMessage(__err),         \
+          ", when calling `" #EXPR "`",                             \
+          ". This error may appear if the input matrix contains NaN."); \
+    } else {                                                        \
+      TORCH_CHECK(                                                  \
+          __err == CUSOLVER_STATUS_SUCCESS,                         \
+          "cusolver error: ",                                       \
+          at::cuda::solver::cusolverGetErrorMessage(__err),         \
+          ", when calling `" #EXPR "`");                            \
+    }                                                               \
   } while (0)

 #else
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 03e14b7a679..1a127050ed5 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -2836,7 +2836,7 @@ Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
   // If the reshaped self is not invertible catch this error
   Tensor result, info;
   std::tie(result, info) = at::linalg_inv_ex(self.reshape({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
-  TORCH_CHECK(info.item<int64_t>() == 0, "Failed to invert the input tensor, because it is singular.");
+  singleCheckErrors(info.item<int64_t>(), "inv");

   return result.reshape(shape_ind_end);
 }
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index 056c1b8b392..0345cc99eb8 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -2,6 +2,7 @@

 #include
 #include
+#include
 #include
 #include
 #include
@@ -261,23 +262,23 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t batch_idx = -1) {
   } else if (info > 0) {
     if (strstr(name, "inv")) {
       // inv, inverse, cholesky_inverse, etc.
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
                   ": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular.");
     } else if (strstr(name, "solve")) {
       // solve, linalg_solve, cholesky_solve, etc.
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
                   ": The diagonal element ", info, " is zero, the solve could not be completed because the input matrix is singular.");
     } else if (strstr(name, "cholesky")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
                   ": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite).");
     } else if (strstr(name, "svd")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
                   ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ").");
     } else if (strstr(name, "eig") || strstr(name, "syevd")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
                   ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ").");
     } else if (strstr(name, "lstsq")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
                   ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
     } else {
       TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
diff --git a/c10/util/Exception.h b/c10/util/Exception.h
index 05416da5bb2..0eb0c6a80bf 100644
--- a/c10/util/Exception.h
+++ b/c10/util/Exception.h
@@ -229,6 +229,12 @@ class C10_API OnnxfiBackendSystemError : public Error {
   using Error::Error;
 };

+// Used for numerical errors from the linalg module. These
+// turn into LinAlgError when they cross into Python.
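+// LinAlgError derives from c10::Error, so existing handlers that catch
+// c10::Error (or std::exception) continue to work unchanged.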
+class C10_API LinAlgError : public Error {
+  using Error::Error;
+};
+
 // A utility function to return an exception std::string by prepending its
 // exception type before its what() content
 C10_API std::string GetExceptionString(const std::exception& e);
@@ -486,6 +492,10 @@ namespace detail {
 // TODO: We're going to get a lot of similar looking string literals
 // this way; check if this actually affects binary size.

+// Like TORCH_CHECK, but raises LinAlgError instead of Error.
+#define TORCH_CHECK_LINALG(cond, ...) \
+  TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__)
+
 // Like TORCH_CHECK, but raises IndexErrors instead of Errors.
 #define TORCH_CHECK_INDEX(cond, ...) \
   TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__)
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 495c7e742a7..041e46edc84 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -504,7 +504,7 @@ class TestLinalg(TestCase):
         # if the input matrix is not positive definite, an error should be raised
         A = torch.eye(3, 3, dtype=dtype, device=device)
         A[-1, -1] = 0  # Now A is not positive definite
-        with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
             torch.linalg.cholesky(A)
         with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
             np.linalg.cholesky(A.cpu().numpy())
@@ -514,7 +514,7 @@ class TestLinalg(TestCase):
         A = A.reshape((1, 3, 3))
         A = A.repeat(5, 1, 1)
         A[4, -1, -1] = 0  # Now A[4] is not positive definite
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 4\): The factorization could not be completed'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
             torch.linalg.cholesky(A)

         # if out tensor with wrong shape is passed a warning is given
@@ -692,7 +692,7 @@ class TestLinalg(TestCase):
         A[-1, -1] = 0  # Now A is singular
         _, info = torch.linalg.cholesky_ex(A)
         self.assertEqual(info, 3)
-        with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
             torch.linalg.cholesky_ex(A, check_errors=True)

         # if at least one matrix in the batch is not positive definite,
@@ -706,7 +706,7 @@ class TestLinalg(TestCase):
         expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
         expected_info[3] = 2
         self.assertEqual(info, expected_info)
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The factorization could not be completed'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
             torch.linalg.cholesky_ex(A, check_errors=True)

     @skipCUDAIfNoMagmaAndNoCusolver
@@ -2943,12 +2943,12 @@ class TestLinalg(TestCase):
         error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
         a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
         a[0] = float('nan')
-        with self.assertRaisesRegex(RuntimeError, error_msg):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
             svd(a)
         error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
         a = torch.randn(3, 33, 33, dtype=dtype, device=device)
         a[1, 0, 0] = float('nan')
-        with self.assertRaisesRegex(RuntimeError, error_msg):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
             svd(a)
     @skipCUDAIfNoMagmaAndNoCusolver
@@ -3296,7 +3296,8 @@ class TestLinalg(TestCase):
         A[-1, -1] = 0  # Now A is singular
         info = torch.linalg.inv_ex(A).info
         self.assertEqual(info, 3)
-        with self.assertRaisesRegex(RuntimeError, r'diagonal element 3 is zero, the inversion could not be completed'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError,
+                                    r'diagonal element 3 is zero, the inversion could not be completed'):
             torch.linalg.inv_ex(A, check_errors=True)

         # if at least one matrix in the batch is not positive definite,
@@ -3310,7 +3311,7 @@ class TestLinalg(TestCase):
         expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
         expected_info[3] = 2
         self.assertEqual(info, expected_info)
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The diagonal element 2 is zero'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
             torch.linalg.inv_ex(A, check_errors=True)

     @slowTest
@@ -3347,7 +3348,7 @@ class TestLinalg(TestCase):
         def run_test_singular_input(batch_dim, n):
             x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
             x[n, -1, -1] = 0
-            with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
                 torch.inverse(x)

         for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -3364,7 +3365,7 @@ class TestLinalg(TestCase):
         x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
         x[:] = torch.eye(616, dtype=dtype, device=device)
         x[..., 10, 10] = 0
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 0\): The diagonal element 11 is zero'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
             torch.inverse(x)

     @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
@@ -3498,7 +3499,7 @@ class TestLinalg(TestCase):
         def run_test_singular_input(batch_dim, n):
             a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
             a[n, -1, -1] = 0
-            with self.assertRaisesRegex(RuntimeError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
                 torch.linalg.inv(a)

         for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -3629,7 +3630,7 @@ class TestLinalg(TestCase):
             a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
             a[n, -1, -1] = 0
             b = torch.randn(batch_dim, 3, 1, dtype=dtype, device=device)
-            with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
                 torch.linalg.solve(a, b)

         for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -4036,7 +4037,7 @@ class TestLinalg(TestCase):
             a = torch.eye(prod_ind_end, dtype=dtype, device=device)
             a[-1, -1] = 0  # Now `a` is singular
             a = a.reshape(a_shape)
-            with self.assertRaisesRegex(RuntimeError, "Failed to invert the input tensor, because it is singular"):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
                 torch.linalg.tensorinv(a, ind=ind)

         # test for non-invertible input
@@ -7627,7 +7628,7 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
         a = torch.randn(3, 3, device=device, dtype=dtype)
         a[1, 1] = 0
         if self.device_type == 'cpu':
-            with self.assertRaisesRegex(RuntimeError, r"cholesky_inverse: The diagonal element 2 is zero"):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"):
                 torch.cholesky_inverse(a)
         # cholesky_inverse on GPU does not raise an error for this case
         elif self.device_type == 'cuda':
diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp
index d9c5c18b82a..8bf89ae7cbd 100644
--- a/torch/csrc/Exceptions.cpp
+++ b/torch/csrc/Exceptions.cpp
@@ -9,13 +9,35 @@

 #include

-PyObject *THPException_FatalError;
+PyObject *THPException_FatalError, *THPException_LinAlgError;

 #define ASSERT_TRUE(cond) if (!(cond)) return false

 bool THPException_init(PyObject *module)
 {
   ASSERT_TRUE(THPException_FatalError = PyErr_NewException("torch.FatalError", nullptr, nullptr));
   ASSERT_TRUE(PyModule_AddObject(module, "FatalError", THPException_FatalError) == 0);
+
+  // Set the doc string here since _add_docstr throws malloc errors if tp_doc is modified
+  // for an error class.
+  ASSERT_TRUE(THPException_LinAlgError = PyErr_NewExceptionWithDoc("torch._C._LinAlgError",
+    "Error raised by torch.linalg functions when the cause of the error is a numerical inconsistency in the data.\n \
+For example, the torch.linalg.inv function will raise torch.linalg.LinAlgError when it finds that \
+a matrix is not invertible.\n \
+\n\
+Example:\n \
+>>> matrix = torch.eye(3, 3)\n \
+>>> matrix[-1, -1] = 0\n \
+>>> matrix\n \
+    tensor([[1., 0., 0.],\n \
+            [0., 1., 0.],\n \
+            [0., 0., 0.]])\n \
+>>> torch.linalg.inv(matrix)\n \
+Traceback (most recent call last):\n \
+File \"<stdin>\", line 1, in <module>\n \
+torch._C._LinAlgError: torch.linalg.inv: The diagonal element 3 is zero, the inversion\n \
+could not be completed because the input matrix is singular.", PyExc_RuntimeError, nullptr));
+  ASSERT_TRUE(PyModule_AddObject(module, "_LinAlgError", THPException_LinAlgError) == 0);
+
   return true;
 }
@@ -176,6 +198,13 @@ AttributeError::AttributeError(const char* format, ...) {
   va_end(fmt_args);
 }

+LinAlgError::LinAlgError(const char* format, ...) {
+  va_list fmt_args;
+  va_start(fmt_args, format);
+  msg = formatMessage(format, fmt_args);
+  va_end(fmt_args);
+}
+
 void PyWarningHandler::InternalHandler::process(
     const c10::SourceLocation& source_location,
     const std::string& msg,
diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h
index fe7e12a502c..348dba3de06 100644
--- a/torch/csrc/Exceptions.h
+++ b/torch/csrc/Exceptions.h
@@ -81,6 +81,12 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
     PyErr_SetString(PyExc_NotImplementedError, torch::processErrorMsg(msg)); \
     retstmnt;                                                                \
   }                                                                          \
+  catch (const c10::LinAlgError& e) {                                        \
+    auto msg = torch::get_cpp_stacktraces_enabled() ?                        \
+        e.what() : e.what_without_backtrace();                               \
+    PyErr_SetString(THPException_LinAlgError, torch::processErrorMsg(msg));  \
+    retstmnt;                                                                \
+  }                                                                          \
   catch (const c10::Error& e) {                                              \
     auto msg = torch::get_cpp_stacktraces_enabled() ?                        \
         e.what() : e.what_without_backtrace();                               \
@@ -151,7 +157,7 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {

 #define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)

-extern PyObject *THPException_FatalError;
+extern PyObject *THPException_FatalError, *THPException_LinAlgError;

 // Throwing this exception means that the python error flags have been already
 // set and control should be immediately returned to the interpreter.
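The catch clause added above is what routes a C++ c10::LinAlgError into the new
Python exception type. A minimal sketch of how a binding picks this up
(illustrative only, not part of this patch; THPVariable_example is a
hypothetical name):

    // Any binding wrapped in HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS gets the
    // new catch clause automatically.
    static PyObject* THPVariable_example(PyObject* self, PyObject* args) {
      HANDLE_TH_ERRORS
      // ... dispatch into ATen; a TORCH_CHECK_LINALG failure throws
      // c10::LinAlgError, which the macro expansion converts into
      // torch._C._LinAlgError before returning nullptr to CPython ...
      Py_RETURN_NONE;
      END_HANDLE_TH_ERRORS
    }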
@@ -334,6 +340,14 @@ struct AttributeError : public PyTorchError {
   }
 };

+// Translates to Python LinAlgError
+struct LinAlgError : public PyTorchError {
+  LinAlgError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
+  PyObject* python_type() override {
+    return THPException_LinAlgError;
+  }
+};
+
 struct WarningMeta {
   WarningMeta(const c10::SourceLocation& _source_location,
       // NOLINTNEXTLINE(modernize-pass-by-value)
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index fa4c3c495ff..801363c48f3 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -4,6 +4,8 @@ import sys
 import torch
 from torch._C import _add_docstr, _linalg  # type: ignore[attr-defined]

+LinAlgError = torch._C._LinAlgError  # type: ignore[attr-defined]
+
 Tensor = torch.Tensor

 common_notes = {
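Example of catching the new exception class after this patch (a minimal
sketch mirroring the docstring example above; since the class is created with
PyExc_RuntimeError as its base, existing `except RuntimeError` handlers keep
working):

    import torch

    A = torch.eye(3)
    A[-1, -1] = 0  # make A singular

    try:
        torch.linalg.inv(A)
    except torch.linalg.LinAlgError as err:
        # Raised instead of a plain RuntimeError for numerical failures.
        print("singular input:", err)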