Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Back out "[pytorch][PR] Performance and memory improvements to batched torch.linalg.solve" (#71421)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71421

Original commit changeset: 7a0dd443cd0e
Original Phabricator Diff: D33028236 (410e91adee)

Test Plan: PyTorch OSS CI

Reviewed By: ngimel

Differential Revision: D33637628

fbshipit-source-id: 1e81485be202b2f9d6a1ff315279cc099754c2dc
(cherry picked from commit c2d730bfeb)
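For context (not part of the original commit message): the change being backed out touched how batched inputs to torch.linalg.solve are broadcast and LU-factored. The public API broadcasts the batch dimensions of A and B and also accepts a plain vector right-hand side; a minimal, illustrative Python example of that documented behavior:

    import torch

    A = torch.randn(4, 3, 3)         # batch of 4 coefficient matrices
    b = torch.randn(3)               # single right-hand-side vector, broadcast over the batch
    x = torch.linalg.solve(A, b)     # shape (4, 3): the "vector rhs" case
    print(x.shape)

    B = torch.randn(1, 3, 2)         # batch dim 1 broadcasts against A's batch of 4
    X = torch.linalg.solve(A, B)     # shape (4, 3, 2)
    print(torch.allclose(A @ X, B.expand(4, 3, 2), atol=1e-4))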
parent 8a9243996c
commit a0ada2d22b
@@ -908,8 +908,8 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
   // _linalg_broadcast_batch_dims also includes linearSolveCheckInputs
   // it checks for squareness of 'input' and 'shape' compatibility of 'other' and 'input'
-  Tensor other_broadcasted;
-  std::tie(other_broadcasted, std::ignore) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");
+  Tensor other_broadcasted, input_broadcasted;
+  std::tie(other_broadcasted, input_broadcasted) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");

   auto squeezed_other_broadcasted = at::squeeze(other_broadcasted, -1);
   auto squeezed_result_shape = squeezed_other_broadcasted.sizes();

@@ -945,17 +945,18 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
   // lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
   result.copy_(other_broadcasted);

+  auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
+
   TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt);
   TORCH_INTERNAL_ASSERT(infos.device() == input.device());
-  infos.resize_({std::max<int64_t>(1, batchCount(input))});
+  infos.resize_({std::max<int64_t>(1, batchCount(input_broadcasted))});
   // if input is empty infos might not get filled; make sure infos doesn't contain garbage then
   if (input.numel() == 0) {
     infos.fill_(0);
   }

   // compute the LU factorization of 'input_working_copy'
-  auto input_working_copy = cloneBatchedColumnMajor(input);
-  auto pivots_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
+  auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec(); // input_broadcasted.shape[:-2]
   pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
   Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
   lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
@@ -978,7 +979,8 @@ Tensor& linalg_solve_out(const Tensor& input, const Tensor& other, Tensor& resul

   // Now check LAPACK/MAGMA error codes
   // batchCheckErrors(Tensor, char*) calls 'infos = infos.to(kCPU)'
-  if (input.dim() > 2) {
+  bool vector_case = linalg_solve_is_vector_rhs(input, other);
+  if (vector_case ? result.dim() > 1 : result.dim() > 2) {
     batchCheckErrors(infos, "linalg.solve");
   } else {
     singleCheckErrors(infos.item().toInt(), "linalg.solve");
@@ -908,8 +908,8 @@ void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, Tra
   const auto trans = to_blas(transpose);
   auto pivots_data = pivots.data_ptr<int>();
   auto b_stride = matrixStride(b);
-  auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
-  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
+  auto lu_stride = matrixStride(lu);
+  auto pivots_stride = pivots.size(-1);
   auto batch_size = batchCount(b);

   auto n = lu.size(-2);
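Background, not part of this commit: the removed expressions lu.dim() > 2 ? lu.stride(-3) : 0 and pivots.dim() > 1 ? pivots.stride(-2) : 0 read the batch stride directly, so an LU factor that is only a broadcast (expanded) view, whose batch stride is 0, can be stepped through batch by batch without materializing per-batch copies; the restored matrixStride(lu) and pivots.size(-1) forms assume densely packed batches. An illustrative Python sketch of the stride behavior involved, using only the public tensor API:

    import torch

    lu = torch.randn(3, 3)            # one factored matrix
    lu_view = lu.expand(8, 3, 3)      # broadcast view over a batch of 8; no data is copied
    print(lu_view.stride())           # (0, 3, 1): the batch dimension has stride 0
    print(lu_view.stride(-3))         # 0, so stepping batch by batch revisits the same storage
    print(lu_view.is_contiguous())    # False: a dense per-matrix stride of 3*3 would be wrong here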
@@ -2838,8 +2838,8 @@ static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const
   auto pivots_data = pivots_cpu.data_ptr<magma_int_t>();

   auto b_stride = matrixStride(b);
-  auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
-  auto pivots_stride = pivots_cpu.dim() > 1 ? pivots_cpu.stride(-2) : 0;
+  auto lu_stride = matrixStride(lu);
+  auto pivots_stride = pivots_cpu.size(-1);
   auto batch_size = batchCount(b);

   magma_int_t n = magma_int_cast(lu.size(-2), "n");
@@ -2883,8 +2883,6 @@ static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, cons
       "Calling torch.lu_solve on a CUDA tensor requires compiling ",
       "PyTorch with MAGMA. Please rebuild with MAGMA.");
 #else
-  TORCH_INTERNAL_ASSERT(batchCount(b) == batchCount(lu), "batch_size of b and lu must be the same");
-  TORCH_INTERNAL_ASSERT(batchCount(lu) == batchCount(pivots.unsqueeze(-1)), "batch_size of lu and pivots must be the same");
   auto trans = to_magma(transpose);
   auto b_data = b.data_ptr<scalar_t>();
   auto lu_data = lu.data_ptr<scalar_t>();
@@ -2951,36 +2949,9 @@ static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tenso
   });
 }

-namespace {
-
-c10::MaybeOwned<Tensor> maybe_expand_lu(const Tensor& b, const Tensor& lu) {
-  if (batchCount(b) != batchCount(lu)) {
-    IntArrayRef b_batch_size(b.sizes().data(), b.dim() - 2);
-    std::vector<int64_t> expand_size = b_batch_size.vec();
-    expand_size.insert(expand_size.end(), {lu.size(-2), lu.size(-1)});
-    return c10::MaybeOwned<Tensor>::owned(
-        cloneBatchedColumnMajor(lu.expand(expand_size)));
-  } else {
-    return c10::MaybeOwned<Tensor>::borrowed(lu);
-  }
-}
-
-c10::MaybeOwned<Tensor> maybe_expand_pivots(const Tensor& b,const Tensor& pivots) {
-  if (batchCount(b) != batchCount(pivots.unsqueeze(-1))) {
-    IntArrayRef b_batch_size(b.sizes().data(), b.dim() - 2);
-    std::vector<int64_t> expand_size = b_batch_size.vec();
-    expand_size.insert(expand_size.end(), {pivots.size(-1)});
-    return c10::MaybeOwned<Tensor>::owned(
-        pivots.expand(expand_size).clone(at::MemoryFormat::Contiguous));
-  } else {
-    return c10::MaybeOwned<Tensor>::borrowed(pivots);
-  }
-}
-
-} // anonymous namespace

 static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
-  auto batch_size = batchCount(b);
+  auto batch_size = batchCount(lu);
   auto m = lu.size(-2);
   auto b2 = b.size(-1);
   bool over_magma_dim_limit = b2 > 1024; // magma implementation of LU solve cannot handle a b tensor with last dim > 1024 (https://bitbucket.org/icl/magma/issues/19/dgesv_batched-dgetrs_batched-fails-for)
@@ -2996,15 +2967,11 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten
 #endif // ifdef USE_CUSOLVER
 #ifdef CUDART_VERSION
   else if ((batch_size > 2 && m <= 128) || (batch_size > 8 && over_magma_dim_limit)) {
-    c10::MaybeOwned<Tensor> lu_ = maybe_expand_lu(b, lu);
-    c10::MaybeOwned<Tensor> pivots_ = maybe_expand_pivots(b, pivots);
-    lu_solve_batched_cublas(b, *lu_, *pivots_, trans);
+    lu_solve_batched_cublas(b, lu, pivots, trans);
   }
 #endif // ifdef CUDART_VERSION
   else {
-    c10::MaybeOwned<Tensor> lu_ = maybe_expand_lu(b, lu);
-    c10::MaybeOwned<Tensor> pivots_ = maybe_expand_pivots(b, pivots);
-    lu_solve_batched_magma(b, *lu_, *pivots_, trans);
+    lu_solve_batched_magma(b, lu, pivots, trans);
   }
 }

@@ -84,8 +84,6 @@ static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, con
 #ifndef CUDART_VERSION
   TORCH_CHECK(false, "lu_solve: cuBLAS backend for lu_solve is not available.")
 #else
-  TORCH_INTERNAL_ASSERT(batchCount(b) == batchCount(lu), "batch_size of b and lu must be the same");
-  TORCH_INTERNAL_ASSERT(batchCount(lu) == batchCount(pivots.unsqueeze(-1)), "batch_size of lu and pivots must be the same");
   const auto trans = to_cublas(transpose);

   auto pivots_data = pivots.data_ptr<int>();
@@ -1469,14 +1467,14 @@ void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& p
   const auto trans = to_cublas(transpose);
   int n = cuda_int_cast(lu.size(-2), "n");
   int nrhs = cuda_int_cast(b.size(-1), "nrhs");
-  auto batch_size = batchCount(b);
+  auto batch_size = batchCount(lu);
   auto info = at::zeros({1}, lu.options().dtype(kInt));
   auto info_data = info.data_ptr<int>();
   auto b_data = b.data_ptr<scalar_t>();
   auto lu_data = lu.data_ptr<scalar_t>();
   auto pivots_data = pivots.data_ptr<int>();
-  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
-  auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
+  auto pivots_stride = pivots.size(-1);
+  auto lu_stride = matrixStride(lu);
   auto b_stride = matrixStride(b);
   int leading_dimension = cuda_int_cast(std::max<int>(1, n), "leading_dimension");
