Fix SVD error code handling for OpenBLAS 0.3.15+ and MKL 2022+ (#68812)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/67693.

Reference LAPACK (used by OpenBLAS) changed the `info` error code that SVD returns when the input contains non-finite values. PyTorch raises an internal assert error for negative `info` codes because a negative code usually indicates a bug in the code calling the backend library. With newer versions of LAPACK, that is no longer the case for SVD. MKL (tested with 2021.4.0) still returns a positive error code for this kind of input. This change makes our error handling compatible with both the OpenBLAS and MKL behavior.

**UPDATE:**
MKL 2022 uses the latest Reference LAPACK behavior and returns the same `info` as OpenBLAS 0.3.15+.
This PR also fixes https://github.com/pytorch/pytorch/issues/71645, which was caused by the updated MKL version in CI.
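
For reference, the LAPACK convention change is observable directly through SciPy's low-level LAPACK wrappers. This is a minimal sketch, assuming a SciPy build linked against Reference LAPACK 3.10+ (e.g. OpenBLAS 0.3.15+):

```python
import numpy as np
from scipy.linalg.lapack import dgesdd  # low-level wrapper for LAPACK's ?gesdd (SVD)

# A non-finite input triggers LAPACK's argument check.
a = np.full((3, 3), np.nan)
u, s, vt, info = dgesdd(a)

# Reference LAPACK 3.10+ flags the non-finite matrix as an illegal argument
# (info == -4); older LAPACK and MKL builds report a convergence failure
# (info > 0) instead.
print(info)
```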

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68812

Reviewed By: mrshenli

Differential Revision: D33844257

Pulled By: ngimel

fbshipit-source-id: fd1c86e37e405b330633d039f49dce466391b66e
(cherry picked from commit c00a9bdeb0)


@@ -306,6 +306,16 @@ static inline void singleCheckErrors(int64_t info, const c10::string_view name,
     batch_string = ": (Batch element " + std::to_string(batch_id) + ")";
   }
   if (info < 0) {
+    // Reference LAPACK 3.10+ changed `info` behavior for inputs with non-finite values:
+    // previously it would return `info` > 0, but now it returns `info` = -4.
+    // OpenBLAS 0.3.15+ uses Reference LAPACK 3.10+.
+    // MKL 2022.0+ uses Reference LAPACK 3.10+.
+    // Older versions of MKL and OpenBLAS follow the old behavior (return `info` > 0).
+    // Here we check for the case where `info` is -4 and raise an error.
+    if (name.find("svd") != name.npos) {
+      TORCH_CHECK_LINALG(info != -4, name, batch_string,
+          ": The algorithm failed to converge because the input matrix contained non-finite values.");
+    }
     TORCH_INTERNAL_ASSERT(false, name, batch_string,
         ": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library.");
   } else if (info > 0) {
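
With this check in place, a non-finite input to SVD surfaces as a catchable `torch.linalg.LinAlgError` instead of tripping the internal assert below it. A minimal sketch of the intended user-facing behavior on CPU (hypothetical session; the exact prefix in the message depends on the `name` passed by the caller):

```python
import torch

a = torch.full((3, 3), float('nan'))
try:
    torch.linalg.svd(a)
except torch.linalg.LinAlgError as e:
    # e.g. "linalg.svd: The algorithm failed to converge because
    # the input matrix contained non-finite values."
    print(e)
```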


@@ -1685,11 +1685,10 @@ class TestLinalg(TestCase):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     def test_norm_extreme_values(self, device):
-        if torch.device(device).type == 'cpu':
-            self.skipTest("Test broken on cpu (see gh-71645)")
         vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
-        matrix_ords = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf]
+        # matrix_ords 'nuc', 2, -2 are skipped currently
+        # See issue https://github.com/pytorch/pytorch/issues/71911
+        matrix_ords = ['fro', 1, inf, -1, -inf]
         vectors = []
         matrices = []
         for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
@@ -1727,8 +1726,8 @@ class TestLinalg(TestCase):
                 if is_broken_matrix_norm_case(ord, x):
                     continue
                 else:
-                    result = torch.linalg.norm(x, ord=ord)
                     result_n = np.linalg.norm(x_n, ord=ord)
+                    result = torch.linalg.norm(x, ord=ord)
                     self.assertEqual(result, result_n, msg=msg)

         # Test degenerate shape results match numpy for linalg.norm vector norms
@@ -2614,6 +2613,26 @@ class TestLinalg(TestCase):
             S_s = torch.svd(A, compute_uv=False).S
             self.assertEqual(S_s, S)

+    @skipMeta
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(*floating_and_complex_types())
+    def test_svd_nan_error(self, device, dtype):
+        for svd in [torch.svd, torch.linalg.svd]:
+            # if input contains NaN then an error is triggered for svd
+            # When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
+            # When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
+            error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
+            a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
+            a[0] = float('nan')
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
+                svd(a)
+            error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
+            a = torch.randn(3, 33, 33, dtype=dtype, device=device)
+            a[1, 0, 0] = float('nan')
+            with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
+                svd(a)
+
     @skipCUDAIfNoMagmaAndNoCusolver
     @skipCPUIfNoLapack
     @dtypes(torch.complex128)