[CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)

As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
Nikita Vedeneev 2025-10-21 20:48:09 +00:00 committed by PyTorch MergeBot
parent 830e789a55
commit 2f38eece7c
2 changed files with 259 additions and 253 deletions

View File

@ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}
static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
/*
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
* Additionally, for ROCM we test whether the architecture supports the Lt.
*/
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif
// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}
return false;
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies 2D bias which we currently not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be ralexed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif
// no compliance by default
return false;
}
#endif
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)
if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be to set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
// if lt path fails, we recurse back into this function here and force the lt path to off
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}
if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// When in non-Lt path we do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
}
}
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
// Short circuit on empty result
if (result.numel() == 0) {
return result;
}
cublasCommonArgs args(mat1, mat2, result);
if (mat1.numel() == 0) {
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
}
cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
);
} // end is_float_output_with_half_input
if (!lt_success) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
// end Lt path
} else {
// No Lt, we use a GEMM instead
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
}
// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
}
} // end GEMM path
// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (useLtInterface && activation == Activation::GELU) {
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif

View File

@ -67,7 +67,21 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# altering peak memory usage not accounting for cublaslt.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 have nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()