torch.linalg routines raise torch.linalg.LinAlgError when a numerical error is found in the computation. (#68571)
Summary:
This PR fixes https://github.com/pytorch/pytorch/issues/64785 by introducing `torch.linalg.LinAlgError` for reporting errors caused by bad values in linear algebra routines, which allows users to easily catch failures that are numerical in origin.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68571

Reviewed By: malfet

Differential Revision: D33254087

Pulled By: albanD

fbshipit-source-id: 94b59000fdb6a9765e397158e526d1f815f18f0f
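The user-visible effect, as a minimal sketch against a build that includes this change: numerical failures in torch.linalg routines can now be caught specifically, instead of through the much broader RuntimeError.

    import torch

    A = torch.eye(3)
    A[-1, -1] = 0  # make A singular

    try:
        torch.linalg.inv(A)
    except torch.linalg.LinAlgError as err:
        # Only numerical failures (singular input, non-convergence, ...)
        # arrive here; other misuse still raises RuntimeError/ValueError.
        print("caught:", err)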
This commit is contained in:
parent 6a84449290
commit d100d98db8

aten/src/ATen/cuda/Exceptions.h
@@ -12,6 +12,7 @@
+#include <c10/util/Exception.h>
 #include <c10/cuda/CUDAException.h>
 
 
 namespace c10 {
 
 class CuDNNError : public c10::Error {
@@ -70,21 +71,28 @@ namespace at { namespace cuda { namespace solver {
 C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
 }}} // namespace at::cuda::solver
 
-#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
-  do {                                                                  \
-    cusolverStatus_t __err = EXPR;                                      \
-    if (__err == CUSOLVER_STATUS_EXECUTION_FAILED) {                    \
-      TORCH_CHECK(__err == CUSOLVER_STATUS_SUCCESS,                     \
-                  "cusolver error: ",                                   \
-                  at::cuda::solver::cusolverGetErrorMessage(__err),     \
-                  ", when calling `" #EXPR "`",                         \
-                  ". This error may appear if the input matrix contains NaN."); \
-    } else {                                                            \
-      TORCH_CHECK(__err == CUSOLVER_STATUS_SUCCESS,                     \
-                  "cusolver error: ",                                   \
-                  at::cuda::solver::cusolverGetErrorMessage(__err),     \
-                  ", when calling `" #EXPR "`");                        \
-    }                                                                   \
+// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
+// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
+#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
+  do {                                                                  \
+    cusolverStatus_t __err = EXPR;                                      \
+    if ((CUDA_VERSION < 11500 &&                                        \
+         __err == CUSOLVER_STATUS_EXECUTION_FAILED) ||                  \
+        (CUDA_VERSION >= 11500 &&                                       \
+         __err == CUSOLVER_STATUS_INVALID_VALUE)) {                     \
+      TORCH_CHECK_LINALG(                                               \
+          false,                                                        \
+          "cusolver error: ",                                           \
+          at::cuda::solver::cusolverGetErrorMessage(__err),             \
+          ", when calling `" #EXPR "`",                                 \
+          ". This error may appear if the input matrix contains NaN."); \
+    } else {                                                            \
+      TORCH_CHECK(                                                      \
+          __err == CUSOLVER_STATUS_SUCCESS,                             \
+          "cusolver error: ",                                           \
+          at::cuda::solver::cusolverGetErrorMessage(__err),             \
+          ", when calling `" #EXPR "`");                                \
+    }                                                                   \
   } while (0)
 
 #else
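A sketch of what the macro change above means at the Python level (assuming a CUDA build backed by cusolver): a NaN-poisoned input that used to surface as a generic RuntimeError now arrives as torch.linalg.LinAlgError on both the CUDA < 11.5 and >= 11.5 code paths.

    import torch

    if torch.cuda.is_available():
        a = torch.full((3, 3), float('nan'), device='cuda')
        try:
            torch.linalg.svd(a)
        except torch.linalg.LinAlgError as err:
            # On CUDA < 11.5 this comes straight from TORCH_CUSOLVER_CHECK;
            # on CUDA >= 11.5 cusolver finishes and the info-array check
            # raises the same exception type instead.
            print("svd on NaN input:", err)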

aten/src/ATen/native/LinearAlgebra.cpp
@@ -2836,7 +2836,7 @@ Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
   // If the reshaped self is not invertible catch this error
   Tensor result, info;
   std::tie(result, info) = at::linalg_inv_ex(self.reshape({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
-  TORCH_CHECK(info.item<int64_t>() == 0, "Failed to invert the input tensor, because it is singular.");
+  singleCheckErrors(info.item<int64_t>(), "inv");
 
   return result.reshape(shape_ind_end);
 }

aten/src/ATen/native/LinearAlgebraUtils.h
@@ -2,6 +2,7 @@
 
 #include <c10/core/ScalarType.h>
 #include <c10/util/irange.h>
+#include <c10/util/Exception.h>
 #include <ATen/ATen.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/TensorUtils.h>
@@ -261,23 +262,23 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t bat
   } else if (info > 0) {
     if (strstr(name, "inv")) {
       // inv, inverse, cholesky_inverse, etc.
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
           ": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular.");
     } else if (strstr(name, "solve")) {
       // solve, linalg_solve, cholesky_solve, etc.
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
           ": The diagonal element ", info, " is zero, the solve could not be completed because the input matrix is singular.");
     } else if (strstr(name, "cholesky")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
           ": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite).");
     } else if (strstr(name, "svd")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
           ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ").");
     } else if (strstr(name, "eig") || strstr(name, "syevd")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
           ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ").");
     } else if (strstr(name, "lstsq")) {
-      TORCH_CHECK(false, name, batch_string,
+      TORCH_CHECK_LINALG(false, name, batch_string,
           ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
     } else {
       TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
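Since singleCheckErrors dispatches on the routine name, the same singular input yields operation-specific messages, all raised as torch.linalg.LinAlgError once this change lands. A sketch:

    import torch

    A = torch.eye(3)
    A[-1, -1] = 0  # singular, and not positive definite

    try:
        torch.linalg.cholesky(A)  # hits the "cholesky" branch
    except torch.linalg.LinAlgError as err:
        print(err)  # ...leading minor of order 3 is not positive-definite...

    try:
        torch.linalg.inv(A)  # hits the "inv" branch
    except torch.linalg.LinAlgError as err:
        print(err)  # ...The diagonal element 3 is zero...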

c10/util/Exception.h
@@ -229,6 +229,12 @@ class C10_API OnnxfiBackendSystemError : public Error {
   using Error::Error;
 };
 
+// Used for numerical errors from the linalg module. These
+// turn into LinAlgError when they cross into Python.
+class C10_API LinAlgError : public Error {
+  using Error::Error;
+};
+
 // A utility function to return an exception std::string by prepending its
 // exception type before its what() content
 C10_API std::string GetExceptionString(const std::exception& e);
@@ -486,6 +492,10 @@ namespace detail {
 // TODO: We're going to get a lot of similar looking string literals
 // this way; check if this actually affects binary size.
 
+// Like TORCH_CHECK, but raises LinAlgError instead of Error.
+#define TORCH_CHECK_LINALG(cond, ...) \
+  TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__)
+
 // Like TORCH_CHECK, but raises IndexErrors instead of Errors.
 #define TORCH_CHECK_INDEX(cond, ...) \
   TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__)
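TORCH_CHECK_LINALG is deliberately separate from TORCH_CHECK: only checks guarding numerical validity are migrated, so argument errors keep their old Python types. A sketch of the resulting split:

    import torch

    # Numerical failure -> torch.linalg.LinAlgError (via TORCH_CHECK_LINALG)
    A = torch.eye(3)
    A[-1, -1] = 0
    try:
        torch.linalg.inv(A)
    except torch.linalg.LinAlgError:
        pass

    # Shape misuse stays on the TORCH_CHECK path -> plain RuntimeError
    try:
        torch.linalg.inv(torch.ones(2, 3))  # not square
    except RuntimeError as err:
        assert not isinstance(err, torch.linalg.LinAlgError)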

test/test_linalg.py
@@ -504,7 +504,7 @@ class TestLinalg(TestCase):
         # if the input matrix is not positive definite, an error should be raised
         A = torch.eye(3, 3, dtype=dtype, device=device)
         A[-1, -1] = 0  # Now A is not positive definite
-        with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
             torch.linalg.cholesky(A)
         with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
             np.linalg.cholesky(A.cpu().numpy())
@@ -514,7 +514,7 @@ class TestLinalg(TestCase):
         A = A.reshape((1, 3, 3))
         A = A.repeat(5, 1, 1)
         A[4, -1, -1] = 0  # Now A[4] is not positive definite
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 4\): The factorization could not be completed'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
             torch.linalg.cholesky(A)
 
         # if out tensor with wrong shape is passed a warning is given
@@ -692,7 +692,7 @@ class TestLinalg(TestCase):
         A[-1, -1] = 0  # Now A is singular
         _, info = torch.linalg.cholesky_ex(A)
         self.assertEqual(info, 3)
-        with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
             torch.linalg.cholesky_ex(A, check_errors=True)
 
         # if at least one matrix in the batch is not positive definite,
@@ -706,7 +706,7 @@ class TestLinalg(TestCase):
         expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
         expected_info[3] = 2
         self.assertEqual(info, expected_info)
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The factorization could not be completed'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
             torch.linalg.cholesky_ex(A, check_errors=True)
 
     @skipCUDAIfNoMagmaAndNoCusolver
@@ -2943,12 +2943,12 @@ class TestLinalg(TestCase):
         error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
         a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
         a[0] = float('nan')
-        with self.assertRaisesRegex(RuntimeError, error_msg):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
             svd(a)
         error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
         a = torch.randn(3, 33, 33, dtype=dtype, device=device)
         a[1, 0, 0] = float('nan')
-        with self.assertRaisesRegex(RuntimeError, error_msg):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
             svd(a)
 
     @skipCUDAIfNoMagmaAndNoCusolver
@@ -3296,7 +3296,8 @@ class TestLinalg(TestCase):
         A[-1, -1] = 0  # Now A is singular
         info = torch.linalg.inv_ex(A).info
         self.assertEqual(info, 3)
-        with self.assertRaisesRegex(RuntimeError, r'diagonal element 3 is zero, the inversion could not be completed'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError,
+                                    r'diagonal element 3 is zero, the inversion could not be completed'):
             torch.linalg.inv_ex(A, check_errors=True)
 
         # if at least one matrix in the batch is not positive definite,
@@ -3310,7 +3311,7 @@ class TestLinalg(TestCase):
         expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
         expected_info[3] = 2
         self.assertEqual(info, expected_info)
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The diagonal element 2 is zero'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
             torch.linalg.inv_ex(A, check_errors=True)
 
     @slowTest
@@ -3347,7 +3348,7 @@ class TestLinalg(TestCase):
         def run_test_singular_input(batch_dim, n):
             x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
             x[n, -1, -1] = 0
-            with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
                 torch.inverse(x)
 
         for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -3364,7 +3365,7 @@ class TestLinalg(TestCase):
         x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
         x[:] = torch.eye(616, dtype=dtype, device=device)
         x[..., 10, 10] = 0
-        with self.assertRaisesRegex(RuntimeError, r'\(Batch element 0\): The diagonal element 11 is zero'):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
             torch.inverse(x)
 
     @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
@@ -3498,7 +3499,7 @@ class TestLinalg(TestCase):
         def run_test_singular_input(batch_dim, n):
             a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
             a[n, -1, -1] = 0
-            with self.assertRaisesRegex(RuntimeError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
                 torch.linalg.inv(a)
 
         for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -3629,7 +3630,7 @@ class TestLinalg(TestCase):
             a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
             a[n, -1, -1] = 0
             b = torch.randn(batch_dim, 3, 1, dtype=dtype, device=device)
-            with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
                 torch.linalg.solve(a, b)
 
         for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
@@ -4036,7 +4037,7 @@ class TestLinalg(TestCase):
         a = torch.eye(prod_ind_end, dtype=dtype, device=device)
         a[-1, -1] = 0  # Now `a` is singular
         a = a.reshape(a_shape)
-        with self.assertRaisesRegex(RuntimeError, "Failed to invert the input tensor, because it is singular"):
+        with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
             torch.linalg.tensorinv(a, ind=ind)
 
         # test for non-invertible input
@@ -7627,7 +7628,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
         a = torch.randn(3, 3, device=device, dtype=dtype)
         a[1, 1] = 0
         if self.device_type == 'cpu':
-            with self.assertRaisesRegex(RuntimeError, r"cholesky_inverse: The diagonal element 2 is zero"):
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"):
                 torch.cholesky_inverse(a)
         # cholesky_inverse on GPU does not raise an error for this case
         elif self.device_type == 'cuda':

torch/csrc/Exceptions.cpp
@@ -9,13 +9,35 @@
 
 #include <torch/csrc/THP.h>
 
-PyObject *THPException_FatalError;
+PyObject *THPException_FatalError, *THPException_LinAlgError;
 
 #define ASSERT_TRUE(cond) if (!(cond)) return false
 bool THPException_init(PyObject *module)
 {
   ASSERT_TRUE(THPException_FatalError = PyErr_NewException("torch.FatalError", nullptr, nullptr));
   ASSERT_TRUE(PyModule_AddObject(module, "FatalError", THPException_FatalError) == 0);
 
+  // Set the doc string here since _add_docstr throws malloc errors if tp_doc is modified
+  // for an error class.
+  ASSERT_TRUE(THPException_LinAlgError = PyErr_NewExceptionWithDoc("torch._C._LinAlgError",
+      "Error raised by torch.linalg functions when the cause of the error is a numerical inconsistency in the data.\n \
+For example, the torch.linalg.inv function will raise torch.linalg.LinAlgError when it finds that \
+a matrix is not invertible.\n \
+\n\
+Example:\n \
+>>> matrix = torch.eye(3, 3)\n \
+>>> matrix[-1, -1] = 0\n \
+>>> matrix\n \
+tensor([[1., 0., 0.],\n \
+        [0., 1., 0.],\n \
+        [0., 0., 0.]])\n \
+>>> torch.linalg.inv(matrix)\n \
+Traceback (most recent call last):\n \
+File \"<stdin>\", line 1, in <module>\n \
+torch._C._LinAlgError: torch.linalg.inv: The diagonal element 3 is zero, the inversion\n \
+could not be completed because the input matrix is singular.", PyExc_RuntimeError, nullptr));
+  ASSERT_TRUE(PyModule_AddObject(module, "_LinAlgError", THPException_LinAlgError) == 0);
+
   return true;
 }
 
@@ -176,6 +198,13 @@ AttributeError::AttributeError(const char* format, ...) {
   va_end(fmt_args);
 }
 
+LinAlgError::LinAlgError(const char* format, ...) {
+  va_list fmt_args;
+  va_start(fmt_args, format);
+  msg = formatMessage(format, fmt_args);
+  va_end(fmt_args);
+}
+
 void PyWarningHandler::InternalHandler::process(
     const c10::SourceLocation& source_location,
     const std::string& msg,
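Note that PyErr_NewExceptionWithDoc above receives PyExc_RuntimeError as the base class, so the change is backward compatible: existing code that catches RuntimeError still catches these failures. A quick check:

    import torch

    assert issubclass(torch.linalg.LinAlgError, RuntimeError)

    A = torch.eye(3)
    A[-1, -1] = 0
    try:
        torch.linalg.inv(A)
    except RuntimeError as err:  # old-style handler still works
        assert isinstance(err, torch.linalg.LinAlgError)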

torch/csrc/Exceptions.h
@@ -81,6 +81,12 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
     PyErr_SetString(PyExc_NotImplementedError, torch::processErrorMsg(msg)); \
     retstmnt;                                                                \
   }                                                                          \
+  catch (const c10::LinAlgError& e) {                                        \
+    auto msg = torch::get_cpp_stacktraces_enabled() ?                        \
+        e.what() : e.what_without_backtrace();                               \
+    PyErr_SetString(THPException_LinAlgError, torch::processErrorMsg(msg));  \
+    retstmnt;                                                                \
+  }                                                                          \
   catch (const c10::Error& e) {                                              \
     auto msg = torch::get_cpp_stacktraces_enabled() ?                        \
         e.what() : e.what_without_backtrace();                               \
@@ -151,7 +157,7 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
 
 #define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
 
-extern PyObject *THPException_FatalError;
+extern PyObject *THPException_FatalError, *THPException_LinAlgError;
 
 // Throwing this exception means that the python error flags have been already
 // set and control should be immediately returned to the interpreter.
@@ -334,6 +340,14 @@ struct AttributeError : public PyTorchError {
   }
 };
 
+// Translates to Python LinAlgError
+struct LinAlgError : public PyTorchError {
+  LinAlgError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
+  PyObject* python_type() override {
+    return THPException_LinAlgError;
+  }
+};
+
 struct WarningMeta {
   WarningMeta(const c10::SourceLocation& _source_location,
       // NOLINTNEXTLINE(modernize-pass-by-value)

torch/linalg/__init__.py
@@ -4,6 +4,8 @@ import sys
 import torch
 from torch._C import _add_docstr, _linalg  # type: ignore[attr-defined]
 
+LinAlgError = torch._C._LinAlgError  # type: ignore[attr-defined]
+
 Tensor = torch.Tensor
 
 common_notes = {
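The alias above is what makes the exception importable as torch.linalg.LinAlgError. A sketch of the kind of recovery code the commit message has in mind (robust_solve is a hypothetical helper, not part of this PR):

    import torch

    def robust_solve(A, b):
        # Solve A x = b; if A is reported singular, fall back to a
        # least-squares answer via the pseudo-inverse.
        try:
            return torch.linalg.solve(A, b)
        except torch.linalg.LinAlgError:
            return torch.linalg.pinv(A) @ b

    A = torch.eye(3)
    A[-1, -1] = 0  # singular
    b = torch.ones(3, 1)
    x = robust_solve(A, b)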