torch.linalg routines raise torch.linalg.LinAlgError when a numerical error is found in the computation. (#68571)

Summary:
This PR fixes https://github.com/pytorch/pytorch/issues/64785 by introducing a `torch.linalg.LinAlgError` exception for reporting errors caused by bad values in linear algebra routines, which lets users catch numerical failures specifically instead of matching on generic RuntimeError messages.
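As a minimal sketch of the intended user-facing behavior (assuming a build that contains this change), numerical failures can now be caught by their own exception class:

    import torch

    matrix = torch.eye(3)
    matrix[-1, -1] = 0.0  # make the matrix singular

    try:
        torch.linalg.inv(matrix)
    except torch.linalg.LinAlgError as err:
        # numerical failures are now distinguishable from other RuntimeErrors
        print(f"caught: {err}")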

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68571

Reviewed By: malfet

Differential Revision: D33254087

Pulled By: albanD

fbshipit-source-id: 94b59000fdb6a9765e397158e526d1f815f18f0f
Authored by Sameer Deshmukh on 2021-12-23 10:51:36 -08:00; committed by Facebook GitHub Bot
parent 6a84449290
commit d100d98db8
8 changed files with 103 additions and 38 deletions

View File

@@ -12,6 +12,7 @@
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>
namespace c10 {
class CuDNNError : public c10::Error {
@@ -70,21 +71,28 @@ namespace at { namespace cuda { namespace solver {
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
}}} // namespace at::cuda::solver
-#define TORCH_CUSOLVER_CHECK(EXPR) \
-  do { \
-    cusolverStatus_t __err = EXPR; \
-    if (__err == CUSOLVER_STATUS_EXECUTION_FAILED) { \
-      TORCH_CHECK(__err == CUSOLVER_STATUS_SUCCESS, \
-                  "cusolver error: ", \
-                  at::cuda::solver::cusolverGetErrorMessage(__err), \
-                  ", when calling `" #EXPR "`", \
-                  ". This error may appear if the input matrix contains NaN."); \
-    } else { \
-      TORCH_CHECK(__err == CUSOLVER_STATUS_SUCCESS, \
-                  "cusolver error: ", \
-                  at::cuda::solver::cusolverGetErrorMessage(__err), \
-                  ", when calling `" #EXPR "`"); \
-    } \
+// When CUDA < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when the input contains NaN.
+// When CUDA >= 11.5, cusolver normally finishes execution and sets the info array to indicate a convergence issue.
+#define TORCH_CUSOLVER_CHECK(EXPR) \
+  do { \
+    cusolverStatus_t __err = EXPR; \
+    if ((CUDA_VERSION < 11500 && \
+         __err == CUSOLVER_STATUS_EXECUTION_FAILED) || \
+        (CUDA_VERSION >= 11500 && \
+         __err == CUSOLVER_STATUS_INVALID_VALUE)) { \
+      TORCH_CHECK_LINALG( \
+          false, \
+          "cusolver error: ", \
+          at::cuda::solver::cusolverGetErrorMessage(__err), \
+          ", when calling `" #EXPR "`", \
+          ". This error may appear if the input matrix contains NaN."); \
+    } else { \
+      TORCH_CHECK( \
+          __err == CUSOLVER_STATUS_SUCCESS, \
+          "cusolver error: ", \
+          at::cuda::solver::cusolverGetErrorMessage(__err), \
+          ", when calling `" #EXPR "`"); \
+    } \
} while (0)
#else
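Both failure modes should surface identically from Python once this change lands. A hedged sketch (assumes a CUDA build with cusolver available), mirroring the NaN-input SVD tests later in this PR:

    import torch

    if torch.cuda.is_available():
        # On CUDA < 11.5 cusolver itself errors out on NaN input; on CUDA >= 11.5
        # the failure is reported through the info array. Either way the user
        # should see torch.linalg.LinAlgError.
        a = torch.full((3, 3), float("nan"), device="cuda")
        try:
            torch.linalg.svd(a)
        except torch.linalg.LinAlgError as err:
            print("convergence failure:", err)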

View File

@@ -2836,7 +2836,7 @@ Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
// If the reshaped self is not invertible catch this error
Tensor result, info;
std::tie(result, info) = at::linalg_inv_ex(self.reshape({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
-  TORCH_CHECK(info.item<int64_t>() == 0, "Failed to invert the input tensor, because it is singular.");
+  singleCheckErrors(info.item<int64_t>(), "inv");
return result.reshape(shape_ind_end);
}
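In practice, tensorinv on a singular input now reports through the shared error path and raises the new exception; a hedged sketch of the changed behavior, mirroring the updated test later in this PR:

    import torch

    a = torch.eye(4)
    a[-1, -1] = 0.0          # singular once viewed as a 4x4 matrix
    a = a.reshape(2, 2, 2, 2)

    try:
        torch.linalg.tensorinv(a, ind=2)
    except torch.linalg.LinAlgError as err:
        assert "The diagonal element" in str(err)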

View File

@@ -2,6 +2,7 @@
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/TensorUtils.h>
@@ -261,23 +262,23 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t bat
} else if (info > 0) {
if (strstr(name, "inv")) {
// inv, inverse, cholesky_inverse, etc.
-    TORCH_CHECK(false, name, batch_string,
+    TORCH_CHECK_LINALG(false, name, batch_string,
": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular.");
} else if (strstr(name, "solve")) {
// solve, linalg_solve, cholesky_solve, etc.
-    TORCH_CHECK(false, name, batch_string,
+    TORCH_CHECK_LINALG(false, name, batch_string,
": The diagonal element ", info, " is zero, the solve could not be completed because the input matrix is singular.");
} else if (strstr(name, "cholesky")) {
-    TORCH_CHECK(false, name, batch_string,
+    TORCH_CHECK_LINALG(false, name, batch_string,
": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite).");
} else if (strstr(name, "svd")) {
-    TORCH_CHECK(false, name, batch_string,
+    TORCH_CHECK_LINALG(false, name, batch_string,
": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ").");
} else if (strstr(name, "eig") || strstr(name, "syevd")) {
-    TORCH_CHECK(false, name, batch_string,
+    TORCH_CHECK_LINALG(false, name, batch_string,
": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ").");
} else if (strstr(name, "lstsq")) {
-    TORCH_CHECK(false, name, batch_string,
+    TORCH_CHECK_LINALG(false, name, batch_string,
": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
} else {
TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
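A hedged illustration of this dispatch from the Python side: the same kind of positive LAPACK info code produces a routine-specific message, all under the new exception type:

    import torch

    bad = torch.eye(3)
    bad[-1, -1] = 0.0

    try:
        torch.linalg.cholesky(bad)   # takes the "cholesky" branch
    except torch.linalg.LinAlgError as err:
        print(err)                   # "... is not positive-definite ..."

    try:
        torch.linalg.inv(bad)        # takes the "inv" branch
    except torch.linalg.LinAlgError as err:
        print(err)                   # "... the input matrix is singular ..."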

View File

@@ -229,6 +229,12 @@ class C10_API OnnxfiBackendSystemError : public Error {
using Error::Error;
};
+// Used for numerical errors from the linalg module. These
+// turn into LinAlgError when they cross into Python.
+class C10_API LinAlgError : public Error {
+  using Error::Error;
+};
// A utility function to return an exception std::string by prepending its
// exception type before its what() content
C10_API std::string GetExceptionString(const std::exception& e);
@@ -486,6 +492,10 @@ namespace detail {
// TODO: We're going to get a lot of similar looking string literals
// this way; check if this actually affects binary size.
+// Like TORCH_CHECK, but raises LinAlgError instead of Error.
+#define TORCH_CHECK_LINALG(cond, ...) \
+  TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__)
// Like TORCH_CHECK, but raises IndexErrors instead of Errors.
#define TORCH_CHECK_INDEX(cond, ...) \
TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__)
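Because the Python-side class is created with RuntimeError as its base (see the PyErr_NewExceptionWithDoc call in a later hunk), pre-existing code that catches RuntimeError keeps working; a hedged sketch:

    import torch

    assert issubclass(torch.linalg.LinAlgError, RuntimeError)

    try:
        torch.linalg.inv(torch.zeros(2, 2))  # singular input
    except RuntimeError:
        print("legacy RuntimeError handlers still catch linalg failures")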

View File

@@ -504,7 +504,7 @@ class TestLinalg(TestCase):
# if the input matrix is not positive definite, an error should be raised
A = torch.eye(3, 3, dtype=dtype, device=device)
A[-1, -1] = 0 # Now A is not positive definite
-with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
torch.linalg.cholesky(A)
with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
np.linalg.cholesky(A.cpu().numpy())
@@ -514,7 +514,7 @@ class TestLinalg(TestCase):
A = A.reshape((1, 3, 3))
A = A.repeat(5, 1, 1)
A[4, -1, -1] = 0 # Now A[4] is not positive definite
-with self.assertRaisesRegex(RuntimeError, r'\(Batch element 4\): The factorization could not be completed'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
torch.linalg.cholesky(A)
# if out tensor with wrong shape is passed a warning is given
@@ -692,7 +692,7 @@ class TestLinalg(TestCase):
A[-1, -1] = 0 # Now A is singular
_, info = torch.linalg.cholesky_ex(A)
self.assertEqual(info, 3)
-with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
torch.linalg.cholesky_ex(A, check_errors=True)
# if at least one matrix in the batch is not positive definite,
@@ -706,7 +706,7 @@ class TestLinalg(TestCase):
expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
expected_info[3] = 2
self.assertEqual(info, expected_info)
-with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The factorization could not be completed'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
torch.linalg.cholesky_ex(A, check_errors=True)
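These tests exercise the two reporting modes of the _ex variants; a hedged sketch of the pattern they verify:

    import torch

    a = torch.eye(3)
    a[-1, -1] = 0.0

    # check_errors=False (the default): no exception; the caller inspects info
    L, info = torch.linalg.cholesky_ex(a)
    print(int(info))  # nonzero signals failure (here: 3)

    # check_errors=True: the same failure raises the new exception
    try:
        torch.linalg.cholesky_ex(a, check_errors=True)
    except torch.linalg.LinAlgError as err:
        print(err)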
@skipCUDAIfNoMagmaAndNoCusolver
@@ -2943,12 +2943,12 @@ class TestLinalg(TestCase):
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
a[0] = float('nan')
-with self.assertRaisesRegex(RuntimeError, error_msg):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
svd(a)
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
a = torch.randn(3, 33, 33, dtype=dtype, device=device)
a[1, 0, 0] = float('nan')
-with self.assertRaisesRegex(RuntimeError, error_msg):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
svd(a)
@skipCUDAIfNoMagmaAndNoCusolver
@@ -3296,7 +3296,8 @@ class TestLinalg(TestCase):
A[-1, -1] = 0 # Now A is singular
info = torch.linalg.inv_ex(A).info
self.assertEqual(info, 3)
-with self.assertRaisesRegex(RuntimeError, r'diagonal element 3 is zero, the inversion could not be completed'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError,
+                            r'diagonal element 3 is zero, the inversion could not be completed'):
torch.linalg.inv_ex(A, check_errors=True)
# if at least one matrix in the batch is not positive definite,
@@ -3310,7 +3311,7 @@ class TestLinalg(TestCase):
expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
expected_info[3] = 2
self.assertEqual(info, expected_info)
-with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The diagonal element 2 is zero'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
torch.linalg.inv_ex(A, check_errors=True)
@slowTest
@@ -3347,7 +3348,7 @@ class TestLinalg(TestCase):
def run_test_singular_input(batch_dim, n):
x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
x[n, -1, -1] = 0
-with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
torch.inverse(x)
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -3364,7 +3365,7 @@ class TestLinalg(TestCase):
x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
x[:] = torch.eye(616, dtype=dtype, device=device)
x[..., 10, 10] = 0
-with self.assertRaisesRegex(RuntimeError, r'\(Batch element 0\): The diagonal element 11 is zero'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
torch.inverse(x)
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
@@ -3498,7 +3499,7 @@ class TestLinalg(TestCase):
def run_test_singular_input(batch_dim, n):
a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
a[n, -1, -1] = 0
with self.assertRaisesRegex(RuntimeError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
torch.linalg.inv(a)
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -3629,7 +3630,7 @@ class TestLinalg(TestCase):
a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
a[n, -1, -1] = 0
b = torch.randn(batch_dim, 3, 1, dtype=dtype, device=device)
-with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
+with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
torch.linalg.solve(a, b)
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -4036,7 +4037,7 @@ class TestLinalg(TestCase):
a = torch.eye(prod_ind_end, dtype=dtype, device=device)
a[-1, -1] = 0 # Now `a` is singular
a = a.reshape(a_shape)
with self.assertRaisesRegex(RuntimeError, "Failed to invert the input tensor, because it is singular"):
with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
torch.linalg.tensorinv(a, ind=ind)
# test for non-invertible input
@@ -7627,7 +7628,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
a = torch.randn(3, 3, device=device, dtype=dtype)
a[1, 1] = 0
if self.device_type == 'cpu':
with self.assertRaisesRegex(RuntimeError, r"cholesky_inverse: The diagonal element 2 is zero"):
with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"):
torch.cholesky_inverse(a)
# cholesky_inverse on GPU does not raise an error for this case
elif self.device_type == 'cuda':
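As the paired torch/NumPy assertions earlier in this test file suggest, the name mirrors np.linalg.LinAlgError, so dual-backend code can handle both symmetrically; a hedged sketch:

    import numpy as np
    import torch

    a = torch.eye(3)
    a[-1, -1] = 0.0

    try:
        torch.linalg.cholesky(a)
    except torch.linalg.LinAlgError:
        pass  # torch: "... is not positive-definite ..."

    try:
        np.linalg.cholesky(a.numpy())
    except np.linalg.LinAlgError:
        pass  # numpy: "Matrix is not positive definite"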

View File

@@ -9,13 +9,35 @@
#include <torch/csrc/THP.h>
-PyObject *THPException_FatalError;
+PyObject *THPException_FatalError, *THPException_LinAlgError;
#define ASSERT_TRUE(cond) if (!(cond)) return false
bool THPException_init(PyObject *module)
{
ASSERT_TRUE(THPException_FatalError = PyErr_NewException("torch.FatalError", nullptr, nullptr));
ASSERT_TRUE(PyModule_AddObject(module, "FatalError", THPException_FatalError) == 0);
+  // Set the doc string here since _add_docstr throws malloc errors if tp_doc is modified
+  // for an error class.
+  ASSERT_TRUE(THPException_LinAlgError = PyErr_NewExceptionWithDoc("torch._C._LinAlgError",
+      "Error raised by torch.linalg functions when the cause of the error is a numerical inconsistency in the data.\n \
+For example, the torch.linalg.inv function will raise torch.linalg.LinAlgError when it finds that \
+a matrix is not invertible.\n \
+\n\
+Example:\n \
+>>> matrix = torch.eye(3, 3)\n \
+>>> matrix[-1, -1] = 0\n \
+>>> matrix\n \
+tensor([[1., 0., 0.],\n \
+        [0., 1., 0.],\n \
+        [0., 0., 0.]])\n \
+>>> torch.linalg.inv(matrix)\n \
+Traceback (most recent call last):\n \
+File \"<stdin>\", line 1, in <module>\n \
+torch._C._LinAlgError: torch.linalg.inv: The diagonal element 3 is zero, the inversion\n \
+could not be completed because the input matrix is singular.", PyExc_RuntimeError, nullptr));
+  ASSERT_TRUE(PyModule_AddObject(module, "_LinAlgError", THPException_LinAlgError) == 0);
return true;
}
@@ -176,6 +198,13 @@ AttributeError::AttributeError(const char* format, ...) {
va_end(fmt_args);
}
+LinAlgError::LinAlgError(const char* format, ...) {
+  va_list fmt_args;
+  va_start(fmt_args, format);
+  msg = formatMessage(format, fmt_args);
+  va_end(fmt_args);
+}
void PyWarningHandler::InternalHandler::process(
const c10::SourceLocation& source_location,
const std::string& msg,

View File

@@ -81,6 +81,12 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
PyErr_SetString(PyExc_NotImplementedError, torch::processErrorMsg(msg)); \
retstmnt; \
} \
+  catch (const c10::LinAlgError& e) { \
+    auto msg = torch::get_cpp_stacktraces_enabled() ? \
+        e.what() : e.what_without_backtrace(); \
+    PyErr_SetString(THPException_LinAlgError, torch::processErrorMsg(msg)); \
+    retstmnt; \
+  } \
catch (const c10::Error& e) { \
auto msg = torch::get_cpp_stacktraces_enabled() ? \
e.what() : e.what_without_backtrace(); \
@@ -151,7 +157,7 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
-extern PyObject *THPException_FatalError;
+extern PyObject *THPException_FatalError, *THPException_LinAlgError;
// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
@@ -334,6 +340,14 @@ struct AttributeError : public PyTorchError {
}
};
+// Translates to Python LinAlgError
+struct LinAlgError : public PyTorchError {
+  LinAlgError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
+  PyObject* python_type() override {
+    return THPException_LinAlgError;
+  }
+};
struct WarningMeta {
WarningMeta(const c10::SourceLocation& _source_location,
// NOLINTNEXTLINE(modernize-pass-by-value)
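Note the ordering in the error-handling macro earlier in this file: c10::LinAlgError derives from c10::Error, so its catch clause must precede the generic c10::Error handler, or it would be swallowed and reported as a plain RuntimeError. A hedged sketch of what the translation guarantees on the Python side:

    import torch

    try:
        torch.linalg.inv(torch.zeros(3, 3))  # singular input
    except torch._C._LinAlgError as err:     # the exact class set at the boundary
        print(type(err).__name__, err)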

View File

@@ -4,6 +4,8 @@ import sys
import torch
from torch._C import _add_docstr, _linalg # type: ignore[attr-defined]
+LinAlgError = torch._C._LinAlgError  # type: ignore[attr-defined]
Tensor = torch.Tensor
common_notes = {
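This final hunk simply re-exports the C-extension class under the documented name, so both spellings refer to the same type; a quick sketch:

    import torch

    # torch.linalg.LinAlgError is an alias of torch._C._LinAlgError
    assert torch.linalg.LinAlgError is torch._C._LinAlgError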