From d8506ff42b3d0dd8d25ab989967daffba13268cd Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 7 Nov 2022 19:21:24 +0000 Subject: [PATCH] Generalize gesvdjBatched to run whith full_matrices==false (#88502) As brought up in https://github.com/pytorch/pytorch/issues/86234#issuecomment-1268296036, our heuristic for which SVD backend to choose was not great in some cases. The case in which there could be some improvements is when we have a large batch of very small non-square matrices. This PR, adapts the calling code to gesvdj by creating two temporary square buffers to allow to call gesvdjBatched, and then copies back the result into the output buffers. We then modify the heuristic that chooses between gesvdj and gesvdjBatched. Fixes https://github.com/pytorch/pytorch/issues/86234 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88502 Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved, https://github.com/mruberry, https://github.com/xwang233 --- .../cuda/linalg/BatchLinearAlgebraLib.cpp | 58 ++++++++++++++----- test/functorch/test_vmap.py | 2 - test/test_linalg.py | 24 +------- 3 files changed, 46 insertions(+), 38 deletions(-) diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp index 01788e0bdff..89c1246a32d 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp @@ -656,23 +656,21 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso using value_t = typename c10::scalar_value_type::type; int m = cuda_int_cast(A.size(-2), "m"); int n = cuda_int_cast(A.size(-1), "n"); - int k = std::min(m, n); int batchsize = cuda_int_cast(batchCount(A), "batch size"); + int lda = A.stride(-1); + int ldu = compute_uv ? U.stride(-1) : m; + int ldv = compute_uv ? V.stride(-1) : n; // Need to pass allocated memory to the function, otherwise it fails auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * k) : c10::DataPtr{}; - auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * k) : c10::DataPtr{}; + auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{}; + auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{}; auto A_data = A.data_ptr(); auto U_data = compute_uv ? U.data_ptr() : reinterpret_cast(dataPtr_U.get()); auto S_data = S.data_ptr(); auto V_data = compute_uv ? V.data_ptr() : reinterpret_cast(dataPtr_V.get()); - int lda = A.stride(-1); - int ldu = compute_uv ? U.stride(-1) : m; - int ldv = compute_uv ? V.stride(-1) : n; - TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got " "m = ", m, " n = ", n); @@ -695,10 +693,42 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params)); } -inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool compute_uv) { +inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) { + auto m = A.size(-2); + auto n = A.size(-1); + auto k = std::min(m, n); + // The kernel assumes full_matrices == true + // If full_matrices == false and m != n, we create auxiliary tensors of the right size and copy the results back + auto U_ = U; + auto V_ = V; + if (compute_uv && !full_matrices) { + auto sizes = A.sizes().vec(); + if (m > n) { + // Size of U with full_matrices == True + sizes.end()[-1] = m; + // U, V should be a batch of Fortran contiguous arrays + U_ = U.new_empty(sizes).mT(); + } else if (m < n) { + // Size of V with full_matrices == True + sizes.end()[-2] = n; + V_ = V.new_empty(sizes).mT(); + } + } + // Here U_ and V_ are batches of F-contig square matrices + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdjBatched", [&] { - apply_svd_cusolver_gesvdjBatched(A, U, S, V, infos, compute_uv); + apply_svd_cusolver_gesvdjBatched(A, U_, S, V_, infos, compute_uv); }); + + // Copy the result back if we created any new matrix + if (compute_uv && !full_matrices) { + if (!U_.is_alias_of(U)) { + U.copy_(U_.narrow(-1, 0, k)); + } + if (!V_.is_alias_of(V)) { + V.copy_(V_.narrow(-1, 0, k)); + } + } } template @@ -832,21 +862,23 @@ void svd_cusolver(const Tensor& A, const Tensor& V, const Tensor& info) { // Here U and V are F-contig whenever they are defined (i.e. whenever compute_uv=true) - const auto batch_size = batchCount(A); const auto m = A.size(-2); const auto n = A.size(-1); const auto k = std::min(m, n); static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html"; - // The default heuristic is to use gesvdj driver + // The default heuristic is to use the gesvdj driver const auto driver_v = driver.value_or("gesvdj"); if (driver_v == "gesvd") { svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv); } else if (driver_v == "gesvdj") { - if (m <= 32 && n <= 32 && batch_size > 1 && (full_matrices || m == n)) { - svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, compute_uv); + // See the benchmarks in + // https://github.com/pytorch/pytorch/pull/88502#issuecomment-1303860789 + // The m <= 32 && n <= 32 restrictions come from the limitations of the cusolver backend. See the cusolver docs + if (m <= 32 && n <= 32) { + svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv); } else { // gesvdj driver may be numerically unstable for large sized matrix svd_cusolver_gesvdj(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv); diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index e8863781ad3..3acab4172fc 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3236,8 +3236,6 @@ class TestVmapOperatorsOpInfo(TestCase): xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops xfail('sparse.sampled_addmm'), # sparse xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work - xfail('svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test - xfail('linalg.svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format # ---------------------------------------------------------------------- diff --git a/test/test_linalg.py b/test/test_linalg.py index 86790677f56..273c74d4e61 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1540,7 +1540,7 @@ class TestLinalg(TestCase): @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.cfloat, torch.cdouble) - @precisionOverride({torch.cfloat: 2e-4}) + @precisionOverride({torch.cfloat: 5e-4}) def test_norm_complex(self, device, dtype): def gen_error_message(input_size, ord, keepdim, dim=None): return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( @@ -2476,28 +2476,6 @@ class TestLinalg(TestCase): result = torch.linalg.svd(a, full_matrices=False) self.assertEqual(result.S, S) - # This test doesn't work with MAGMA backend https://github.com/pytorch/pytorch/issues/72106 - @skipMeta - @skipCUDAIfRocm - @skipCUDAIfNoCusolver - @skipCPUIfNoLapack - @dtypes(*floating_and_complex_types()) - def test_svd_nan_error(self, device, dtype): - for svd in [torch.svd, torch.linalg.svd]: - # if input contains NaN then an error is triggered for svd - # When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan. - # When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue. - error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)' - a = torch.full((3, 3), float('nan'), dtype=dtype, device=device) - a[0] = float('nan') - with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg): - svd(a) - error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)' - a = torch.randn(3, 33, 33, dtype=dtype, device=device) - a[1, 0, 0] = float('nan') - with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg): - svd(a) - def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix