Generalize gesvdjBatched to run whith full_matrices==false (#88502)

As brought up in https://github.com/pytorch/pytorch/issues/86234#issuecomment-1268296036, our heuristic for which SVD backend to choose was not great in some cases.
The case in which there could be some improvements is when we have a
large batch of very small non-square matrices.

This PR, adapts the calling code to gesvdj by creating two temporary
square buffers to allow to call gesvdjBatched, and then copies back the
result into the output buffers.

We then modify the heuristic that chooses between gesvdj and
gesvdjBatched.

Fixes https://github.com/pytorch/pytorch/issues/86234
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88502
Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved, https://github.com/mruberry, https://github.com/xwang233
This commit is contained in:
lezcano 2022-11-07 19:21:24 +00:00 committed by PyTorch MergeBot
parent 9dadf8fcc2
commit d8506ff42b
3 changed files with 46 additions and 38 deletions

View File

@ -656,23 +656,21 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso
using value_t = typename c10::scalar_value_type<scalar_t>::type;
int m = cuda_int_cast(A.size(-2), "m");
int n = cuda_int_cast(A.size(-1), "n");
int k = std::min(m, n);
int batchsize = cuda_int_cast(batchCount(A), "batch size");
int lda = A.stride(-1);
int ldu = compute_uv ? U.stride(-1) : m;
int ldv = compute_uv ? V.stride(-1) : n;
// Need to pass allocated memory to the function, otherwise it fails
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * k) : c10::DataPtr{};
auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * k) : c10::DataPtr{};
auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{};
auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{};
auto A_data = A.data_ptr<scalar_t>();
auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
auto S_data = S.data_ptr<value_t>();
auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
int lda = A.stride(-1);
int ldu = compute_uv ? U.stride(-1) : m;
int ldv = compute_uv ? V.stride(-1) : n;
TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
"m = ", m, " n = ", n);
@ -695,10 +693,42 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso
TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool compute_uv) {
inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) {
auto m = A.size(-2);
auto n = A.size(-1);
auto k = std::min(m, n);
// The kernel assumes full_matrices == true
// If full_matrices == false and m != n, we create auxiliary tensors of the right size and copy the results back
auto U_ = U;
auto V_ = V;
if (compute_uv && !full_matrices) {
auto sizes = A.sizes().vec();
if (m > n) {
// Size of U with full_matrices == True
sizes.end()[-1] = m;
// U, V should be a batch of Fortran contiguous arrays
U_ = U.new_empty(sizes).mT();
} else if (m < n) {
// Size of V with full_matrices == True
sizes.end()[-2] = n;
V_ = V.new_empty(sizes).mT();
}
}
// Here U_ and V_ are batches of F-contig square matrices
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
apply_svd_cusolver_gesvdjBatched<scalar_t>(A, U, S, V, infos, compute_uv);
apply_svd_cusolver_gesvdjBatched<scalar_t>(A, U_, S, V_, infos, compute_uv);
});
// Copy the result back if we created any new matrix
if (compute_uv && !full_matrices) {
if (!U_.is_alias_of(U)) {
U.copy_(U_.narrow(-1, 0, k));
}
if (!V_.is_alias_of(V)) {
V.copy_(V_.narrow(-1, 0, k));
}
}
}
template<typename scalar_t>
@ -832,21 +862,23 @@ void svd_cusolver(const Tensor& A,
const Tensor& V,
const Tensor& info) {
// Here U and V are F-contig whenever they are defined (i.e. whenever compute_uv=true)
const auto batch_size = batchCount(A);
const auto m = A.size(-2);
const auto n = A.size(-1);
const auto k = std::min(m, n);
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
// The default heuristic is to use gesvdj driver
// The default heuristic is to use the gesvdj driver
const auto driver_v = driver.value_or("gesvdj");
if (driver_v == "gesvd") {
svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv);
} else if (driver_v == "gesvdj") {
if (m <= 32 && n <= 32 && batch_size > 1 && (full_matrices || m == n)) {
svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, compute_uv);
// See the benchmarks in
// https://github.com/pytorch/pytorch/pull/88502#issuecomment-1303860789
// The m <= 32 && n <= 32 restrictions come from the limitations of the cusolver backend. See the cusolver docs
if (m <= 32 && n <= 32) {
svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
} else {
// gesvdj driver may be numerically unstable for large sized matrix
svd_cusolver_gesvdj(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);

View File

@ -3236,8 +3236,6 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops
xfail('sparse.sampled_addmm'), # sparse
xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work
xfail('svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test
xfail('linalg.svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
# ----------------------------------------------------------------------

View File

@ -1540,7 +1540,7 @@ class TestLinalg(TestCase):
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.cfloat, torch.cdouble)
@precisionOverride({torch.cfloat: 2e-4})
@precisionOverride({torch.cfloat: 5e-4})
def test_norm_complex(self, device, dtype):
def gen_error_message(input_size, ord, keepdim, dim=None):
return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % (
@ -2476,28 +2476,6 @@ class TestLinalg(TestCase):
result = torch.linalg.svd(a, full_matrices=False)
self.assertEqual(result.S, S)
# This test doesn't work with MAGMA backend https://github.com/pytorch/pytorch/issues/72106
@skipMeta
@skipCUDAIfRocm
@skipCUDAIfNoCusolver
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_svd_nan_error(self, device, dtype):
for svd in [torch.svd, torch.linalg.svd]:
# if input contains NaN then an error is triggered for svd
# When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
# When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
a[0] = float('nan')
with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
svd(a)
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
a = torch.randn(3, 33, 33, dtype=dtype, device=device)
a[1, 0, 0] = float('nan')
with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
svd(a)
def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix