Back out "[pytorch][PR] Performance and memory improvements to batched torch.linalg.solve" (#71421)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71421

Original commit changeset: 7a0dd443cd0e

Original Phabricator Diff: D33028236 (410e91adee)

Test Plan: PyTorch OSS CI

Reviewed By: ngimel

Differential Revision: D33637628

fbshipit-source-id: 1e81485be202b2f9d6a1ff315279cc099754c2dc
(cherry picked from commit c2d730bfeb)
Author: Mike Ruberry, 2022-01-19 09:21:50 -08:00 (committed by PyTorch MergeBot)
parent 8a9243996c
commit a0ada2d22b
4 changed files with 18 additions and 51 deletions
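
Context for the diffs below: torch.linalg.solve broadcasts the batch dimensions of the coefficient matrix and the right-hand side, and the backed-out change altered how that broadcast is materialized internally. A minimal sketch of the user-facing call, assuming the libtorch C++ frontend exposes torch::linalg::solve (shapes here are illustrative):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto A = torch::randn({3, 3});        // unbatched coefficient matrix
  auto B = torch::randn({4, 3, 1});     // batch of four right-hand sides
  auto X = torch::linalg::solve(A, B);  // A is broadcast against B's batch dims
  std::cout << X.sizes() << "\n";       // [4, 3, 1]
  // Residual check: A @ X should reproduce B for every batch element.
  std::cout << (torch::matmul(A, X) - B).abs().max() << "\n";
  return 0;
}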

View File

@@ -908,8 +908,8 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
   // _linalg_broadcast_batch_dims also includes linearSolveCheckInputs
   // it checks for squareness of 'input' and 'shape' compatibility of 'other' and 'input'
-  Tensor other_broadcasted;
-  std::tie(other_broadcasted, std::ignore) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");
+  Tensor other_broadcasted, input_broadcasted;
+  std::tie(other_broadcasted, input_broadcasted) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");
 
   auto squeezed_other_broadcasted = at::squeeze(other_broadcasted, -1);
   auto squeezed_result_shape = squeezed_other_broadcasted.sizes();
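
The two versions differ only in whether the common batch shape is also materialized for 'input': the backed-out code discarded it (std::ignore), the restored code expands 'input' to the broadcast shape. A toy sketch of that batch-shape computation, not the actual _linalg_broadcast_batch_dims implementation:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy batch-dimension broadcast: drop each operand's trailing matrix
// dimensions, align the rest from the right, take the larger extent
// (a size of 1 broadcasts).
std::vector<int64_t> broadcast_batch_shape(std::vector<int64_t> a,
                                           std::vector<int64_t> b) {
  a.resize(a.size() >= 2 ? a.size() - 2 : 0);
  b.resize(b.size() >= 2 ? b.size() - 2 : 0);
  std::vector<int64_t> out(std::max(a.size(), b.size()), 1);
  for (size_t i = 0; i < out.size(); ++i) {
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    assert(da == db || da == 1 || db == 1);  // otherwise not broadcastable
    out[out.size() - 1 - i] = std::max(da, db);
  }
  return out;
}

int main() {
  // input (3, 3) against other (4, 3, 1): the common batch shape is (4),
  // so the restored code expands input to (4, 3, 3).
  for (auto d : broadcast_batch_shape({3, 3}, {4, 3, 1})) {
    std::cout << d << " ";
  }
  std::cout << "\n";
  return 0;
}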
@@ -945,17 +945,18 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
   // lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
   result.copy_(other_broadcasted);
 
+  auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
+
   TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt);
   TORCH_INTERNAL_ASSERT(infos.device() == input.device());
-  infos.resize_({std::max<int64_t>(1, batchCount(input))});
+  infos.resize_({std::max<int64_t>(1, batchCount(input_broadcasted))});
   // if input is empty infos might not get filled; make sure infos doesn't contain garbage then
   if (input.numel() == 0) {
     infos.fill_(0);
   }
 
   // compute the LU factorization of 'input_working_copy'
-  auto input_working_copy = cloneBatchedColumnMajor(input);
-  auto pivots_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
+  auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec(); // input_broadcasted.shape[:-2]
   pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
   Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
   lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
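
The restored code sizes 'infos' and the pivot tensor from the broadcast input: one entry per batch element, where the batch count is the product of all dimensions except the trailing two, and the pivots take shape input.shape[:-2] + [min(m, n)]. A standalone sketch of that shape arithmetic (batch_count and pivots_shape are illustrative helpers, not the ATen ones):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// batchCount: product of every dimension except the trailing two.
int64_t batch_count(const std::vector<int64_t>& sizes) {
  int64_t count = 1;
  for (size_t i = 0; i + 2 < sizes.size(); ++i) {
    count *= sizes[i];
  }
  return count;
}

// Pivot shape used above: input.shape[:-2] + [min(m, n)].
std::vector<int64_t> pivots_shape(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> shape(sizes.begin(), sizes.end() - 2);
  shape.push_back(std::min(sizes[sizes.size() - 2], sizes.back()));
  return shape;
}

int main() {
  std::vector<int64_t> input_broadcasted{4, 3, 3};
  std::cout << batch_count(input_broadcasted) << "\n";  // 4 -> infos gets 4 entries
  for (auto d : pivots_shape(input_broadcasted)) {
    std::cout << d << " ";                              // 4 3
  }
  std::cout << "\n";
  return 0;
}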
@@ -978,7 +979,8 @@ Tensor& linalg_solve_out(const Tensor& input, const Tensor& other, Tensor& resul
   // Now check LAPACK/MAGMA error codes
   // batchCheckErrors(Tensor, char*) calls 'infos = infos.to(kCPU)'
-  if (input.dim() > 2) {
+  bool vector_case = linalg_solve_is_vector_rhs(input, other);
+  if (vector_case ? result.dim() > 1 : result.dim() > 2) {
     batchCheckErrors(infos, "linalg.solve");
   } else {
     singleCheckErrors(infos.item().toInt(), "linalg.solve");

View File

@@ -908,8 +908,8 @@ void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, Tra
   const auto trans = to_blas(transpose);
   auto pivots_data = pivots.data_ptr<int>();
   auto b_stride = matrixStride(b);
-  auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
-  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
+  auto lu_stride = matrixStride(lu);
+  auto pivots_stride = pivots.size(-1);
   auto batch_size = batchCount(b);
   auto n = lu.size(-2);
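
The backed-out lines addressed each batch element through lu.stride(-3), which is 0 when lu is a broadcast (expanded) view, so a single factorization could be reused without copies; the restored matrixStride(lu) assumes a densely materialized batch. A small sketch contrasting the two offset schemes (values are illustrative):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch = 4, n = 3;
  // matrixStride(lu) for a materialized (4, 3, 3) batch: every element owns
  // its own n*n block, so iteration i reads lu_data + i * n * n.
  const int64_t dense_stride = n * n;
  // The backed-out code used lu.stride(-3), which is 0 when lu is merely a
  // broadcast view of one (3, 3) factorization, so every i reads offset 0.
  const int64_t broadcast_stride = 0;
  for (int64_t i = 0; i < batch; ++i) {
    std::cout << "batch " << i << ": dense offset " << i * dense_stride
              << ", broadcast-view offset " << i * broadcast_stride << "\n";
  }
  return 0;
}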

View File

@@ -2838,8 +2838,8 @@ static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const
   auto pivots_data = pivots_cpu.data_ptr<magma_int_t>();
   auto b_stride = matrixStride(b);
-  auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
-  auto pivots_stride = pivots_cpu.dim() > 1 ? pivots_cpu.stride(-2) : 0;
+  auto lu_stride = matrixStride(lu);
+  auto pivots_stride = pivots_cpu.size(-1);
   auto batch_size = batchCount(b);
   magma_int_t n = magma_int_cast(lu.size(-2), "n");
@@ -2883,8 +2883,6 @@ static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, cons
       "Calling torch.lu_solve on a CUDA tensor requires compiling ",
       "PyTorch with MAGMA. Please rebuild with MAGMA.");
 #else
-  TORCH_INTERNAL_ASSERT(batchCount(b) == batchCount(lu), "batch_size of b and lu must be the same");
-  TORCH_INTERNAL_ASSERT(batchCount(lu) == batchCount(pivots.unsqueeze(-1)), "batch_size of lu and pivots must be the same");
   auto trans = to_magma(transpose);
   auto b_data = b.data_ptr<scalar_t>();
   auto lu_data = lu.data_ptr<scalar_t>();
@@ -2951,36 +2949,9 @@ static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tenso
   });
 }
 
-namespace {
-
-c10::MaybeOwned<Tensor> maybe_expand_lu(const Tensor& b, const Tensor& lu) {
-  if (batchCount(b) != batchCount(lu)) {
-    IntArrayRef b_batch_size(b.sizes().data(), b.dim() - 2);
-    std::vector<int64_t> expand_size = b_batch_size.vec();
-    expand_size.insert(expand_size.end(), {lu.size(-2), lu.size(-1)});
-    return c10::MaybeOwned<Tensor>::owned(
-        cloneBatchedColumnMajor(lu.expand(expand_size)));
-  } else {
-    return c10::MaybeOwned<Tensor>::borrowed(lu);
-  }
-}
-
-c10::MaybeOwned<Tensor> maybe_expand_pivots(const Tensor& b, const Tensor& pivots) {
-  if (batchCount(b) != batchCount(pivots.unsqueeze(-1))) {
-    IntArrayRef b_batch_size(b.sizes().data(), b.dim() - 2);
-    std::vector<int64_t> expand_size = b_batch_size.vec();
-    expand_size.insert(expand_size.end(), {pivots.size(-1)});
-    return c10::MaybeOwned<Tensor>::owned(
-        pivots.expand(expand_size).clone(at::MemoryFormat::Contiguous));
-  } else {
-    return c10::MaybeOwned<Tensor>::borrowed(pivots);
-  }
-}
-
-} // anonymous namespace
-
 static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
-  auto batch_size = batchCount(b);
+  auto batch_size = batchCount(lu);
   auto m = lu.size(-2);
   auto b2 = b.size(-1);
   bool over_magma_dim_limit = b2 > 1024; // magma implementation of LU solve cannot handle a b tensor with last dim > 1024 (https://bitbucket.org/icl/magma/issues/19/dgesv_batched-dgetrs_batched-fails-for)
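
The removed maybe_expand_lu/maybe_expand_pivots helpers materialized lu and pivots to b's batch shape only when a batched backend actually needed it; after the revert that expansion happens upstream again. A standalone sketch of the shape arithmetic the removed helper performed (expanded_lu_shape is an illustrative name):

#include <cstdint>
#include <iostream>
#include <vector>

// Shape arithmetic of the removed maybe_expand_lu: b's batch dimensions
// followed by lu's matrix dimensions, i.e. the shape lu was expanded (and
// then copied) to before a batched backend call.
std::vector<int64_t> expanded_lu_shape(const std::vector<int64_t>& b_sizes,
                                       const std::vector<int64_t>& lu_sizes) {
  std::vector<int64_t> expand_size(b_sizes.begin(), b_sizes.end() - 2);  // b.shape[:-2]
  expand_size.push_back(lu_sizes[lu_sizes.size() - 2]);
  expand_size.push_back(lu_sizes.back());
  return expand_size;
}

int main() {
  // b: a (4, 3, 1) batch of right-hand sides, lu: a single (3, 3) factorization.
  for (auto d : expanded_lu_shape({4, 3, 1}, {3, 3})) {
    std::cout << d << " ";  // 4 3 3
  }
  std::cout << "\n";
  return 0;
}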
@@ -2996,15 +2967,11 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten
 #endif // ifdef USE_CUSOLVER
 #ifdef CUDART_VERSION
   else if ((batch_size > 2 && m <= 128) || (batch_size > 8 && over_magma_dim_limit)) {
-    c10::MaybeOwned<Tensor> lu_ = maybe_expand_lu(b, lu);
-    c10::MaybeOwned<Tensor> pivots_ = maybe_expand_pivots(b, pivots);
-    lu_solve_batched_cublas(b, *lu_, *pivots_, trans);
+    lu_solve_batched_cublas(b, lu, pivots, trans);
   }
 #endif // ifdef CUDART_VERSION
   else {
-    c10::MaybeOwned<Tensor> lu_ = maybe_expand_lu(b, lu);
-    c10::MaybeOwned<Tensor> pivots_ = maybe_expand_pivots(b, pivots);
-    lu_solve_batched_magma(b, *lu_, *pivots_, trans);
+    lu_solve_batched_magma(b, lu, pivots, trans);
   }
 }
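
For reference, the visible part of the dispatch heuristic, restated as a standalone predicate (pick_backend is an illustrative name; the cuSOLVER branch guarded by USE_CUSOLVER is not shown in this hunk and is omitted):

#include <cstdint>
#include <iostream>
#include <string>

// Small matrices with more than two batch elements (or many batches whose
// rhs width exceeds MAGMA's 1024-column limit) go to batched cuBLAS,
// otherwise batched MAGMA.
std::string pick_backend(int64_t batch_size, int64_t m, int64_t nrhs) {
  const bool over_magma_dim_limit = nrhs > 1024;
  if ((batch_size > 2 && m <= 128) || (batch_size > 8 && over_magma_dim_limit)) {
    return "cublas_batched";
  }
  return "magma_batched";
}

int main() {
  std::cout << pick_backend(16, 64, 1) << "\n";      // cublas_batched
  std::cout << pick_backend(2, 512, 1) << "\n";      // magma_batched
  std::cout << pick_backend(32, 512, 2048) << "\n";  // cublas_batched
  return 0;
}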

View File

@@ -84,8 +84,6 @@ static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, con
 #ifndef CUDART_VERSION
   TORCH_CHECK(false, "lu_solve: cuBLAS backend for lu_solve is not available.")
 #else
-  TORCH_INTERNAL_ASSERT(batchCount(b) == batchCount(lu), "batch_size of b and lu must be the same");
-  TORCH_INTERNAL_ASSERT(batchCount(lu) == batchCount(pivots.unsqueeze(-1)), "batch_size of lu and pivots must be the same");
   const auto trans = to_cublas(transpose);
   auto pivots_data = pivots.data_ptr<int>();
@@ -1469,14 +1467,14 @@ void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& p
   const auto trans = to_cublas(transpose);
   int n = cuda_int_cast(lu.size(-2), "n");
   int nrhs = cuda_int_cast(b.size(-1), "nrhs");
-  auto batch_size = batchCount(b);
+  auto batch_size = batchCount(lu);
   auto info = at::zeros({1}, lu.options().dtype(kInt));
   auto info_data = info.data_ptr<int>();
   auto b_data = b.data_ptr<scalar_t>();
   auto lu_data = lu.data_ptr<scalar_t>();
   auto pivots_data = pivots.data_ptr<int>();
-  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
-  auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
+  auto pivots_stride = pivots.size(-1);
+  auto lu_stride = matrixStride(lu);
   auto b_stride = matrixStride(b);
   int leading_dimension = cuda_int_cast(std::max<int>(1, n), "leading_dimension");
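
With the revert, the looped cuSOLVER path takes its batch count from lu (already expanded upstream) and advances each pointer by a dense per-matrix stride. A sketch of that loop structure; fake_getrs stands in for the real cuSOLVER getrs binding and the strides are illustrative:

#include <cstdint>
#include <iostream>

// fake_getrs stands in for the real cuSOLVER getrs call; only the pointer
// arithmetic is the point here.
static void fake_getrs(int64_t lu_off, int64_t pivots_off, int64_t b_off) {
  std::cout << "getrs(lu+" << lu_off << ", pivots+" << pivots_off
            << ", b+" << b_off << ")\n";
}

int main() {
  const int64_t n = 3, nrhs = 2, batch_size = 4;  // batch_size = batchCount(lu)
  const int64_t lu_stride = n * n;                // matrixStride(lu)
  const int64_t pivots_stride = n;                // pivots.size(-1)
  const int64_t b_stride = n * nrhs;              // matrixStride(b)
  for (int64_t i = 0; i < batch_size; ++i) {
    fake_getrs(i * lu_stride, i * pivots_stride, i * b_stride);
  }
  return 0;
}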