mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Generalize gesvdjBatched to run whith full_matrices==false (#88502)
As brought up in https://github.com/pytorch/pytorch/issues/86234#issuecomment-1268296036, our heuristic for which SVD backend to choose was not great in some cases. The case in which there could be some improvements is when we have a large batch of very small non-square matrices. This PR, adapts the calling code to gesvdj by creating two temporary square buffers to allow to call gesvdjBatched, and then copies back the result into the output buffers. We then modify the heuristic that chooses between gesvdj and gesvdjBatched. Fixes https://github.com/pytorch/pytorch/issues/86234 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88502 Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved, https://github.com/mruberry, https://github.com/xwang233
This commit is contained in:
parent
9dadf8fcc2
commit
d8506ff42b
|
|
@ -656,23 +656,21 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso
|
||||||
using value_t = typename c10::scalar_value_type<scalar_t>::type;
|
using value_t = typename c10::scalar_value_type<scalar_t>::type;
|
||||||
int m = cuda_int_cast(A.size(-2), "m");
|
int m = cuda_int_cast(A.size(-2), "m");
|
||||||
int n = cuda_int_cast(A.size(-1), "n");
|
int n = cuda_int_cast(A.size(-1), "n");
|
||||||
int k = std::min(m, n);
|
|
||||||
int batchsize = cuda_int_cast(batchCount(A), "batch size");
|
int batchsize = cuda_int_cast(batchCount(A), "batch size");
|
||||||
|
int lda = A.stride(-1);
|
||||||
|
int ldu = compute_uv ? U.stride(-1) : m;
|
||||||
|
int ldv = compute_uv ? V.stride(-1) : n;
|
||||||
|
|
||||||
// Need to pass allocated memory to the function, otherwise it fails
|
// Need to pass allocated memory to the function, otherwise it fails
|
||||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||||
auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * k) : c10::DataPtr{};
|
auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{};
|
||||||
auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * k) : c10::DataPtr{};
|
auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{};
|
||||||
|
|
||||||
auto A_data = A.data_ptr<scalar_t>();
|
auto A_data = A.data_ptr<scalar_t>();
|
||||||
auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
|
auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
|
||||||
auto S_data = S.data_ptr<value_t>();
|
auto S_data = S.data_ptr<value_t>();
|
||||||
auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
|
auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
|
||||||
|
|
||||||
int lda = A.stride(-1);
|
|
||||||
int ldu = compute_uv ? U.stride(-1) : m;
|
|
||||||
int ldv = compute_uv ? V.stride(-1) : n;
|
|
||||||
|
|
||||||
TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
|
TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
|
||||||
"m = ", m, " n = ", n);
|
"m = ", m, " n = ", n);
|
||||||
|
|
||||||
|
|
@ -695,10 +693,42 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso
|
||||||
TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
|
TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool compute_uv) {
|
inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) {
|
||||||
|
auto m = A.size(-2);
|
||||||
|
auto n = A.size(-1);
|
||||||
|
auto k = std::min(m, n);
|
||||||
|
// The kernel assumes full_matrices == true
|
||||||
|
// If full_matrices == false and m != n, we create auxiliary tensors of the right size and copy the results back
|
||||||
|
auto U_ = U;
|
||||||
|
auto V_ = V;
|
||||||
|
if (compute_uv && !full_matrices) {
|
||||||
|
auto sizes = A.sizes().vec();
|
||||||
|
if (m > n) {
|
||||||
|
// Size of U with full_matrices == True
|
||||||
|
sizes.end()[-1] = m;
|
||||||
|
// U, V should be a batch of Fortran contiguous arrays
|
||||||
|
U_ = U.new_empty(sizes).mT();
|
||||||
|
} else if (m < n) {
|
||||||
|
// Size of V with full_matrices == True
|
||||||
|
sizes.end()[-2] = n;
|
||||||
|
V_ = V.new_empty(sizes).mT();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Here U_ and V_ are batches of F-contig square matrices
|
||||||
|
|
||||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
|
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
|
||||||
apply_svd_cusolver_gesvdjBatched<scalar_t>(A, U, S, V, infos, compute_uv);
|
apply_svd_cusolver_gesvdjBatched<scalar_t>(A, U_, S, V_, infos, compute_uv);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Copy the result back if we created any new matrix
|
||||||
|
if (compute_uv && !full_matrices) {
|
||||||
|
if (!U_.is_alias_of(U)) {
|
||||||
|
U.copy_(U_.narrow(-1, 0, k));
|
||||||
|
}
|
||||||
|
if (!V_.is_alias_of(V)) {
|
||||||
|
V.copy_(V_.narrow(-1, 0, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
|
|
@ -832,21 +862,23 @@ void svd_cusolver(const Tensor& A,
|
||||||
const Tensor& V,
|
const Tensor& V,
|
||||||
const Tensor& info) {
|
const Tensor& info) {
|
||||||
// Here U and V are F-contig whenever they are defined (i.e. whenever compute_uv=true)
|
// Here U and V are F-contig whenever they are defined (i.e. whenever compute_uv=true)
|
||||||
const auto batch_size = batchCount(A);
|
|
||||||
const auto m = A.size(-2);
|
const auto m = A.size(-2);
|
||||||
const auto n = A.size(-1);
|
const auto n = A.size(-1);
|
||||||
const auto k = std::min(m, n);
|
const auto k = std::min(m, n);
|
||||||
|
|
||||||
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
|
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
|
||||||
|
|
||||||
// The default heuristic is to use gesvdj driver
|
// The default heuristic is to use the gesvdj driver
|
||||||
const auto driver_v = driver.value_or("gesvdj");
|
const auto driver_v = driver.value_or("gesvdj");
|
||||||
|
|
||||||
if (driver_v == "gesvd") {
|
if (driver_v == "gesvd") {
|
||||||
svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv);
|
svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv);
|
||||||
} else if (driver_v == "gesvdj") {
|
} else if (driver_v == "gesvdj") {
|
||||||
if (m <= 32 && n <= 32 && batch_size > 1 && (full_matrices || m == n)) {
|
// See the benchmarks in
|
||||||
svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, compute_uv);
|
// https://github.com/pytorch/pytorch/pull/88502#issuecomment-1303860789
|
||||||
|
// The m <= 32 && n <= 32 restrictions come from the limitations of the cusolver backend. See the cusolver docs
|
||||||
|
if (m <= 32 && n <= 32) {
|
||||||
|
svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
|
||||||
} else {
|
} else {
|
||||||
// gesvdj driver may be numerically unstable for large sized matrix
|
// gesvdj driver may be numerically unstable for large sized matrix
|
||||||
svd_cusolver_gesvdj(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
|
svd_cusolver_gesvdj(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
|
||||||
|
|
|
||||||
|
|
@ -3236,8 +3236,6 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||||
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops
|
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops
|
||||||
xfail('sparse.sampled_addmm'), # sparse
|
xfail('sparse.sampled_addmm'), # sparse
|
||||||
xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work
|
xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work
|
||||||
xfail('svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test
|
|
||||||
xfail('linalg.svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test
|
|
||||||
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
|
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
|
||||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -1540,7 +1540,7 @@ class TestLinalg(TestCase):
|
||||||
@skipCUDAIfNoMagma
|
@skipCUDAIfNoMagma
|
||||||
@skipCPUIfNoLapack
|
@skipCPUIfNoLapack
|
||||||
@dtypes(torch.cfloat, torch.cdouble)
|
@dtypes(torch.cfloat, torch.cdouble)
|
||||||
@precisionOverride({torch.cfloat: 2e-4})
|
@precisionOverride({torch.cfloat: 5e-4})
|
||||||
def test_norm_complex(self, device, dtype):
|
def test_norm_complex(self, device, dtype):
|
||||||
def gen_error_message(input_size, ord, keepdim, dim=None):
|
def gen_error_message(input_size, ord, keepdim, dim=None):
|
||||||
return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % (
|
return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % (
|
||||||
|
|
@ -2476,28 +2476,6 @@ class TestLinalg(TestCase):
|
||||||
result = torch.linalg.svd(a, full_matrices=False)
|
result = torch.linalg.svd(a, full_matrices=False)
|
||||||
self.assertEqual(result.S, S)
|
self.assertEqual(result.S, S)
|
||||||
|
|
||||||
# This test doesn't work with MAGMA backend https://github.com/pytorch/pytorch/issues/72106
|
|
||||||
@skipMeta
|
|
||||||
@skipCUDAIfRocm
|
|
||||||
@skipCUDAIfNoCusolver
|
|
||||||
@skipCPUIfNoLapack
|
|
||||||
@dtypes(*floating_and_complex_types())
|
|
||||||
def test_svd_nan_error(self, device, dtype):
|
|
||||||
for svd in [torch.svd, torch.linalg.svd]:
|
|
||||||
# if input contains NaN then an error is triggered for svd
|
|
||||||
# When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
|
|
||||||
# When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
|
|
||||||
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)'
|
|
||||||
a = torch.full((3, 3), float('nan'), dtype=dtype, device=device)
|
|
||||||
a[0] = float('nan')
|
|
||||||
with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
|
|
||||||
svd(a)
|
|
||||||
error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)'
|
|
||||||
a = torch.randn(3, 33, 33, dtype=dtype, device=device)
|
|
||||||
a[1, 0, 0] = float('nan')
|
|
||||||
with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg):
|
|
||||||
svd(a)
|
|
||||||
|
|
||||||
def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
|
def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
|
||||||
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user