Fix SVD error code handling for OpenBLAS 0.3.15+ and MKL 2022+ (again) (#72357)

Summary:
This PR was opened as copy of https://github.com/pytorch/pytorch/pull/68812 by request https://github.com/pytorch/pytorch/pull/68812#issuecomment-1030215862.

-----

Fixes https://github.com/pytorch/pytorch/issues/67693.

Reference LAPACK (used in OpenBLAS) changed info error code for svd when inputs contain non-finite numbers. In PyTorch, we raise an internal assert error for negative `info` error codes because usually, it would indicate the wrong implementation. However, this is not the case with SVD now in newer versions of LAPACK. MKL (tried 2021.4.0) still gives a positive error code for this kind of input. This change aligns with the OpenBLAS and MKL behavior in our code.

MKL 2022 has uses the latest reference LAPACK behavior and returns the same `info` as OpenBLAS 0.3.15+
This PR also fixes https://github.com/pytorch/pytorch/issues/71645 that is due to the updated MKL version in CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72357

Reviewed By: albanD

Differential Revision: D34012245

Pulled By: ngimel

fbshipit-source-id: 2b66c173cc3458d8c766b542d0d569191cdce310
(cherry picked from commit fa29e65611)
This commit is contained in:
Ivan Yashchuk 2022-02-07 13:30:37 -08:00 committed by PyTorch MergeBot
parent bc1fb7a618
commit 29c81bbff5
2 changed files with 36 additions and 5 deletions

View File

@ -314,6 +314,16 @@ static inline void singleCheckErrors(int64_t info, const c10::string_view name,
batch_string = ": (Batch element " + std::to_string(batch_id) + ")";
}
if (info < 0) {
// Reference LAPACK 3.10+ changed `info` behavior for inputs with non-finite values
// Previously, it would return `info` > 0, but now it returns `info` = -4
// OpenBLAS 0.3.15+ uses the Reference LAPACK 3.10+.
// MKL 2022.0+ uses the Reference LAPACK 3.10+.
// Older version of MKL and OpenBLAS follow the old behavior (return `info` > 0).
// Here we check for the case where `info` is -4 and raise an error
if (name.find("svd") != name.npos) {
TORCH_CHECK_LINALG(info != -4, name, batch_string,
": The algorithm failed to converge because the input matrix contained non-finite values.");
}
TORCH_INTERNAL_ASSERT(false, name, batch_string,
": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library.");
} else if (info > 0) {

View File

@ -1685,11 +1685,10 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
def test_norm_extreme_values(self, device):
if torch.device(device).type == 'cpu':
self.skipTest("Test broken on cpu (see gh-71645)")
vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
matrix_ords = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf]
# matrix_ords 'nuc', 2, -2 are skipped currently
# See issue https://github.com/pytorch/pytorch/issues/71911
matrix_ords = ['fro', 1, inf, -1, -inf]
vectors = []
matrices = []
for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
@ -1727,8 +1726,8 @@ class TestLinalg(TestCase):
if is_broken_matrix_norm_case(ord, x):
continue
else:
result = torch.linalg.norm(x, ord=ord)
result_n = np.linalg.norm(x_n, ord=ord)
result = torch.linalg.norm(x, ord=ord)
self.assertEqual(result, result_n, msg=msg)
# Test degenerate shape results match numpy for linalg.norm vector norms
@ -2651,6 +2650,28 @@ class TestLinalg(TestCase):
result = torch.linalg.svd(a, full_matrices=False)
self.assertEqual(result.S, S)
# This test doesn't work with MAGMA backend https://github.com/pytorch/pytorch/issues/72106
@skipMeta
@skipCUDAIfRocm
@skipCUDAIfNoCusolver
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_svd_nan_error(self, device, dtype):
for svd in [torch.svd, torch.linalg.svd]:
# if input contains NaN then an error is triggered for svd
# When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
# When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
a[0] = float('nan')
with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
svd(a)
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
a = torch.randn(3, 33, 33, dtype=dtype, device=device)
a[1, 0, 0] = float('nan')
with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
svd(a)
def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix