[CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)
As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy

This commit is contained in:
parent 830e789a55
commit 2f38eece7c
@@ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
 }
 }

-static bool getDisableAddmmCudaLt() {
-  static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
-  if (env_value == "1") {
-    return true;
-  }
-  return false;
+/*
+ * Checks whether DISABLE_ADDMM_CUDA_LT is set.
+ * Additionally, for ROCm we test whether the architecture supports the Lt.
+ */
+static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
+  // When hipBLASLt is not supported on the architecture, return true
+#ifdef USE_ROCM
+  static const std::vector<std::string> archs = {
+      "gfx90a", "gfx942",
+#if ROCM_VERSION >= 60300
+      "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
+#endif
+#if ROCM_VERSION >= 70000
+      "gfx950", "gfx1150", "gfx1151"
+#endif
+  };
+  const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
+  if (!is_hipblas_lt_arch_supported) {
+    return true;
+  }
+#endif
+
+  // Check whether it is disabled in the env
+  static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
+  if (is_addmm_cuda_lt_disabled == "1") {
+    return true;
+  }
+
+  return false;
 }

-#ifdef USE_ROCM
-static bool isSupportedHipLtROCmArch(int index) {
-  static const std::vector<std::string> archs = {
-      "gfx90a", "gfx942",
-#if ROCM_VERSION >= 60300
-      "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
-#endif
-#if ROCM_VERSION >= 70000
-      "gfx950", "gfx1150", "gfx1151"
-#endif
-  };
-  return at::detail::getCUDAHooks().isGPUArch(archs, index);
+/*
+ * Check whether for the given input we want to enable the Lt interface
+ */
+static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
+  // Implies a 2D bias, which we currently do not send through Lt.
+  // TODO: this check is done before the col-major input preparation,
+  // so this condition can be relaxed in cases when a col-major
+  // copy of result is needed.
+  if (result.is_same(self)) {
+    return false;
+  }
+
+#if defined(USE_ROCM) && ROCM_VERSION == 60400
+  // hipblaslt TT fp32 regression on ROCm 6.4, cannot use
+  const auto args = cublasCommonArgs(mat1, mat2, result);
+  if (args.transa == 't' && args.transb == 't') {
+    return false;
+  }
+#endif
+
+  const auto mat1_sizes = mat1.sizes();
+  const auto mat2_sizes = mat2.sizes();
+#if defined(CUDA_VERSION) || defined(USE_ROCM)
+  const auto scalar_type = mat1.scalar_type();
+  return (beta.toComplexDouble() == 1.0
+      // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
+      // is to use the Lt interface only when self is a bias.
+      && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
+      && result.dim() == 2 && result.is_contiguous()
+      && ( // some dtype restrictions
+#ifndef USE_ROCM
+        scalar_type == at::ScalarType::Double ||
+#endif
+        scalar_type == at::ScalarType::Float ||
+        scalar_type == at::ScalarType::Half ||
+        scalar_type == at::ScalarType::BFloat16
+      )
+      && ( // some shape/stride restrictions
+        // Strangely, if mat2 has only 1 row or column, we get
+        // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
+        // NOTE: extended to mat1 because mat1/mat2 can be swapped based off
+        // their row-/col-majorness.
+        mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
+        mat2_sizes[0] > 1 && mat2_sizes[1] > 1
+        // The last conditions are to skip 16b transA and non-trans-B having
+        // leading dim >> rows when they are sliced from a large tensor
+        // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
+#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
+        // Related to avoiding the leading stride >> leading dim problematic case
+        // with 16b dtypes described above. For such dtypes we only allow inputs
+        // which are either row- or col-major (i.e. non-overlapping, compact memory layout).
+        // In that case the leading stride will be equal to the outer dim len.
+        // Why do we catch this case here? The following `prepare_matrix_for_cublas` method
+        // does not modify inputs as long as there is a stride of length 1
+        // and the leading stride is at least max(1, other dim length), so we might
+        // end up with contiguous cols but not rows (i.e. holes between different rows)
+        // and vice versa.
+        && mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32
+        && mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
+        && (
+          // filter by dtype
+          (scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
+          // check mat1/mat2 is row-/col-major
+          (mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
+        )
+#endif
+      )
+  );
+#endif
+
+  // no compliance by default
+  return false;
 }
-#endif

 template <typename scalar_t>
 void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
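Note that the DISABLE_ADDMM_CUDA_LT lookup above is cached in a function-local static, so the variable has to be set before the first addmm-like CUDA call of the process; flipping it later has no effect. A minimal usage sketch:

    import os

    # Must happen before the first addmm-like CUDA call: the helper above
    # caches the env lookup in a function-local static.
    os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"

    import torch

    a = torch.randn(8, 16, device="cuda")
    b = torch.randn(16, 32, device="cuda")
    bias = torch.randn(32, device="cuda")
    out = torch.addmm(bias, a, b)  # routed to the plain GEMM path even though Lt-eligible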
@@ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
 }
 }

+template <typename scalar_t, typename res_scalar_t = scalar_t>
+bool launchGemmAndBiasCublasLt(
+    // args contains result which is modified
+    cublasCommonArgs& args,
+    const Tensor& self,
+    const Scalar& alpha,
+    Activation activation = Activation::None
+) {
+  const auto* self_ptr = self.const_data_ptr<scalar_t>();
+
+  const auto tuning_ctx = at::cuda::tunable::getTuningContext();
+  if (tuning_ctx->IsTunableOpEnabled()) {
+    // TODO: maybe also return some success state?
+    launchTunableGemmAndBias<scalar_t>(
+        args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
+    );
+    return true;
+  }
+
+  return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
+      args.transa == 't',
+      args.transb == 't',
+      args.m,
+      args.n,
+      args.k,
+      alpha.to<at::opmath_type<scalar_t>>(),
+      args.mata->const_data_ptr<scalar_t>(),
+      args.lda,
+      args.matb->const_data_ptr<scalar_t>(),
+      args.ldb,
+      self_ptr,
+      args.result->data_ptr<res_scalar_t>(),
+      args.result_ld,
+      activation_to_gemm_and_blas_arg(activation)
+  );
+}
+
+template <typename scalar_t, typename res_scalar_t = scalar_t>
+bool launchGemmCublas(
+    // args contains result which is modified
+    cublasCommonArgs& args,
+    const Scalar& alpha,
+    const Scalar& beta
+) {
+  at::cuda::blas::gemm<scalar_t, res_scalar_t>(
+      args.transa,
+      args.transb,
+      args.m,
+      args.n,
+      args.k,
+      alpha.to<at::opmath_type<scalar_t>>(),
+      args.mata->const_data_ptr<scalar_t>(),
+      args.lda,
+      args.matb->const_data_ptr<scalar_t>(),
+      args.ldb,
+      beta.to<at::opmath_type<scalar_t>>(),
+      args.result->data_ptr<res_scalar_t>(),
+      args.result_ld
+  );
+  return true; // success!
+}
+
 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
+  // Shape checks {
   // Make sure to keep addmm_cuda below in sync with this code; it
   // preflights a check to try to avoid actually needing to call
   // expand().
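The Lt wrapper above consults the TunableOp tuning context before dispatching. A sketch of toggling that context from Python (the torch.cuda.tunable API and the PYTORCH_TUNABLEOP_ENABLED=1 environment variable should be equivalent here; availability depends on the build):

    import torch

    # Makes IsTunableOpEnabled() in the wrapper above return true, so the
    # Lt wrapper dispatches to launchTunableGemmAndBias instead of gemm_and_bias
    torch.cuda.tunable.enable(True)

    a = torch.randn(64, 64, device="cuda", dtype=torch.half)
    b = torch.randn(64, 64, device="cuda", dtype=torch.half)
    bias = torch.randn(64, device="cuda", dtype=torch.half)
    out = torch.addmm(bias, a, b)

    torch.cuda.tunable.enable(False)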
@@ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
|
||||
if (result.is_same(self)) {
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
|
||||
}
|
||||
// } Shape checks
|
||||
|
||||
// NOLINTNEXTLINE(*c-array*)
|
||||
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
|
||||
checkAllSameGPU(__func__, targs);
|
||||
|
||||
IntArrayRef mat1_sizes = mat1.sizes();
|
||||
IntArrayRef mat2_sizes = mat2.sizes();
|
||||
IntArrayRef self__sizes;
|
||||
bool useLtInterface = false;
|
||||
#if defined(USE_ROCM)
|
||||
// When hipBLASLt is not supported on the architecture,
|
||||
// disable_addmm_cuda_lt will always be to set to true
|
||||
static bool disable_addmm_cuda_lt =
|
||||
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
|
||||
#else
|
||||
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
|
||||
#endif
|
||||
// Handle whether to use the Lt interface {
|
||||
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
|
||||
// if lt path fails, we recurse back into this function here and force the lt path to off
|
||||
// we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent
|
||||
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
cublasCommonArgs _args(mat1, mat2, result);
|
||||
if (_args.transa == 't' && _args.transb == 't') {
|
||||
disable_addmm_cuda_lt_final = true;
|
||||
}
|
||||
#endif
|
||||
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#ifdef USE_ROCM
|
||||
// Conditioned on the device index, which is not persistent
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
// is to use lt interface only when self is bias.
|
||||
// for cuda 11.4, cublasLtMatmul is activated
|
||||
// the last two conditions is to skip 16b transA and non-trans-B having
|
||||
// leading dim >> rows when they are sliced from a large tensor
|
||||
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
|
||||
if (!disable_addmm_cuda_lt_final) {
|
||||
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
|
||||
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
|
||||
self.is_contiguous() && result.is_contiguous() &&
|
||||
#ifdef USE_ROCM
|
||||
(scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#else
|
||||
(scalar_type == at::ScalarType::Double ||
|
||||
scalar_type == at::ScalarType::Float ||
|
||||
scalar_type == at::ScalarType::Half ||
|
||||
scalar_type == at::ScalarType::BFloat16) &&
|
||||
#endif
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
|
||||
#else
|
||||
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
|
||||
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
// avoid leading dim >> rows bugs
|
||||
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
|
||||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16)) &&
|
||||
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
|
||||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
|
||||
(scalar_type != at::ScalarType::Half &&
|
||||
scalar_type != at::ScalarType::BFloat16));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
if (!useLtInterface) {
|
||||
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
|
||||
}
|
||||
self__sizes = self_->sizes();
|
||||
} else {
|
||||
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
self__sizes = self_->sizes();
|
||||
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
|
||||
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
|
||||
}
|
||||
|
||||
if (&result != &self) {
|
||||
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
|
||||
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
|
||||
at::native::copy_(result, *self_);
|
||||
// Handle result/self shapes
|
||||
if (!result.is_same(self)) {
|
||||
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
|
||||
|
||||
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
|
||||
if (disable_addmm_cuda_lt) {
|
||||
// When in non-Lt path we do expand self even before
|
||||
// check for beta != 0.0 to make sure that
|
||||
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
|
||||
// runs green.
|
||||
return expand_size(self, result.sizes(), "addmm");
|
||||
}
|
||||
// copy next, should broadcast
|
||||
return c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
}();
|
||||
// We copy bias when in the non-Lt path
|
||||
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
|
||||
// NOTE: self should broadcast over result
|
||||
at::native::copy_(result, *self_maybe_expanded);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
IntArrayRef result_sizes = result.sizes();
|
||||
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
|
||||
// Short circuit on empty result
|
||||
if (result.numel() == 0) {
|
||||
return result;
|
||||
}
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, result);
|
||||
|
||||
if (mat1.numel() == 0) {
|
||||
// Short circuit if the reduction dim is empty
|
||||
if (mat1.sizes()[1] == 0) {
|
||||
// By definition, when beta==0, values in self should be ignored. nans and infs
|
||||
// should not propagate
|
||||
if (beta.toComplexDouble() == 0.) {
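The beta == 0 short-circuit in the trailing context above has user-visible semantics: values in self are ignored entirely, so NaNs or infs in the bias must not reach the output. A quick illustration:

    import torch

    mat1 = torch.randn(4, 0, device="cuda")  # empty reduction dim
    mat2 = torch.randn(0, 8, device="cuda")
    bias = torch.full((4, 8), float("nan"), device="cuda")

    # beta == 0: self is ignored, so the NaNs do not propagate
    out = torch.addmm(bias, mat1, mat2, beta=0)
    assert not out.isnan().any()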
@@ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         result,
         self.expand(result.sizes()),
         at::native::scalar_tensor(
-            beta,
-            self.scalar_type(),
-            std::nullopt /* layout */,
-            at::kCPU,
-            std::nullopt /* pin_memory */));
+            beta,
+            self.scalar_type(),
+            std::nullopt /* layout */,
+            at::kCPU,
+            std::nullopt /* pin_memory */
+        )
+    );
   }

-  cublasCommonArgs args(mat1, mat2, result);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

-  if (useLtInterface) {
-#if defined(USE_ROCM)
-    bool okay = true;
+  // The Lt path
+  if (!disable_addmm_cuda_lt) {
+    bool lt_success = false;
     if (is_float_output_with_half_input) {
+#ifdef USE_ROCM
       TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
-    } else {
-      AT_DISPATCH_FLOATING_TYPES_AND2(
-          at::ScalarType::Half,
-          at::ScalarType::BFloat16,
-          scalar_type,
-          "addmm_cuda_lt",
-          [&] {
-            auto tuning_ctx = at::cuda::tunable::getTuningContext();
-            if (tuning_ctx->IsTunableOpEnabled()) {
-              launchTunableGemmAndBias<scalar_t>(
-                  args,
-                  alpha,
-                  (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
-                  activation_to_gemm_and_blas_arg(activation));
-            } else {
-              okay = at::cuda::blas::gemm_and_bias<scalar_t>(
-                  args.transa == 't',
-                  args.transb == 't',
-                  args.m,
-                  args.n,
-                  args.k,
-                  alpha.to<at::opmath_type<scalar_t>>(),
-                  args.mata->const_data_ptr<scalar_t>(),
-                  args.lda,
-                  args.matb->const_data_ptr<scalar_t>(),
-                  args.ldb,
-                  // This condition is needed for mm case on ROCm for hipblasLt path.
-                  // Passing the bias ptr as null to avoid accuracy issues for mm case.
-                  (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
-                  args.result->data_ptr<scalar_t>(),
-                  args.result_ld,
-                  activation_to_gemm_and_blas_arg(activation)
-              );
-            }
-          });
-    }
-    if (!okay) {
-      // lt path failed; recurse but disable lt path
-      return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
-    }
-#else
-    auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
-    bool okay = true;
-    if (is_float_output_with_half_input) {
+#else
+      if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
+        TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
+      }
       AT_DISPATCH_REDUCED_FLOATING_TYPES(
           scalar_type,
           "addmm_cuda_lt",
           [&] {
-            auto tuning_ctx = at::cuda::tunable::getTuningContext();
-            if (tuning_ctx->IsTunableOpEnabled()) {
-              TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
+            lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
           }
-            else {
-              okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
-                  args.transa == 't',
-                  args.transb == 't',
-                  args.m,
-                  args.n,
-                  args.k,
-                  alpha.to<at::opmath_type<scalar_t>>(),
-                  args.mata->const_data_ptr<scalar_t>(),
-                  args.lda,
-                  args.matb->const_data_ptr<scalar_t>(),
-                  args.ldb,
-                  self.const_data_ptr<scalar_t>(),
-                  args.result->data_ptr<float>(),
-                  args.result_ld,
-                  activation_epilogue
-              );
-            }});
+      );
+#endif
     } else {
+      // !is_float_output_with_half_input
       AT_DISPATCH_FLOATING_TYPES_AND2(
           at::ScalarType::Half,
           at::ScalarType::BFloat16,
           scalar_type,
           "addmm_cuda_lt",
           [&] {
-            auto tuning_ctx = at::cuda::tunable::getTuningContext();
-            if (tuning_ctx->IsTunableOpEnabled()) {
-              launchTunableGemmAndBias<scalar_t>(
-                  args,
-                  alpha,
-                  self.const_data_ptr<scalar_t>(),
-                  activation_epilogue);
+            lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
           }
-            else {
-              okay = at::cuda::blas::gemm_and_bias<scalar_t>(
-                  args.transa == 't',
-                  args.transb == 't',
-                  args.m,
-                  args.n,
-                  args.k,
-                  alpha.to<at::opmath_type<scalar_t>>(),
-                  args.mata->const_data_ptr<scalar_t>(),
-                  args.lda,
-                  args.matb->const_data_ptr<scalar_t>(),
-                  args.ldb,
-                  self.const_data_ptr<scalar_t>(),
-                  args.result->data_ptr<scalar_t>(),
-                  args.result_ld,
-                  activation_epilogue
-              );
-            }});
     }
-    if (!okay) {
-      // lt path failed; recurse but disable lt path
+      );
+    } // end is_float_output_with_half_input
+
+    if (!lt_success) {
+      // lt path failed; recurse but disable lt path
       return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
     }
-#endif
-  } else
-  {
+    // end Lt path
+  } else {
+    // No Lt, we use a GEMM instead
     if (is_float_output_with_half_input) {
       AT_DISPATCH_REDUCED_FLOATING_TYPES(
           scalar_type,
           "addmm_cuda",
           [&] {
-            using opmath_t = at::opmath_type<scalar_t>;
-            opmath_t alpha_val = alpha.to<opmath_t>();
-            opmath_t beta_val = beta.to<opmath_t>();
-            const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
-            const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
-
-            float* result_ptr = args.result->mutable_data_ptr<float>();
-            at::cuda::blas::gemm<scalar_t, float>(
-                args.transa,
-                args.transb,
-                args.m,
-                args.n,
-                args.k,
-                alpha_val,
-                mat1_ptr,
-                args.lda,
-                mat2_ptr,
-                args.ldb,
-                beta_val,
-                result_ptr,
-                args.result_ld);
-          });
+            launchGemmCublas<scalar_t, float>(args, alpha, beta);
+          }
+      );
     } else {
       AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
           at::ScalarType::Half,
@@ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
           scalar_type,
           "addmm_cuda",
           [&] {
-            using opmath_t = at::opmath_type<scalar_t>;
-            opmath_t alpha_val = alpha.to<opmath_t>();
-            opmath_t beta_val = beta.to<opmath_t>();
-            const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
-            const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
-            scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
-            at::cuda::blas::gemm<scalar_t>(
-                args.transa,
-                args.transb,
-                args.m,
-                args.n,
-                args.k,
-                alpha_val,
-                mat1_ptr,
-                args.lda,
-                mat2_ptr,
-                args.ldb,
-                beta_val,
-                result_ptr,
-                args.result_ld);
-          });
+            launchGemmCublas<scalar_t>(args, alpha, beta);
+          }
+      );
     }

+    // Apply epilogue
     switch (activation) {
       case Activation::RELU:
         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@@ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         break;
       default: break;
     }
-  }
+  } // end GEMM path

   // Preprocessor gate here needs to match the inverse of the check
   // gating activation_to_gemm_and_blas_arg above; here we are manually
   // performing a post-GELU because we weren't able to use the GELU
   // epilogue above.
 #if !defined(CUDA_VERSION) && !defined(USE_ROCM)
-  if (useLtInterface && activation == Activation::GELU) {
+  if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
     at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
   }
 #endif
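The GELU epilogue gated here is reachable from Python through the fused addmm-with-activation op; a sketch (torch._addmm_activation is an internal op, so its availability and exact numerics may vary across versions):

    import torch

    bias = torch.randn(128, device="cuda", dtype=torch.half)
    mat1 = torch.randn(32, 64, device="cuda", dtype=torch.half)
    mat2 = torch.randn(64, 128, device="cuda", dtype=torch.half)

    # Fused GEMM + bias + GELU (tanh approximation on the Lt path)
    out = torch._addmm_activation(bias, mat1, mat2, use_gelu=True)

    # Unfused reference composition
    ref = torch.nn.functional.gelu(torch.addmm(bias, mat1, mat2), approximate="tanh")
    torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)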

@@ -67,7 +67,21 @@ class TestFullyShardMemory(FSDPTest):
         # allocate the cuBLAS workspaces before measuring the memory usage
         # since the workspace size can differ between hardware
         lin = torch.nn.Linear(768, 768, device=device_type)
-        inp = torch.randn(1, 768, device=device_type)
+        # NOTE: before https://github.com/pytorch/pytorch/pull/163955,
+        # the input shape was (1, 768), so that the forward gemm used
+        # cublaslt, and the backward used cublas.
+        # With the aforementioned PR, and with shape (1, 768),
+        # the cublas path is used both in forward and in backward,
+        # altering the peak memory usage, which would no longer account for cublaslt.
+        # Here we change the input shape to (2, 768), which swaps
+        # the cublas/cublaslt selection in the forward/backward,
+        # but does not affect the peak memory usage stored in `base_mem_mb`.
+        # Reasons for the flip:
+        # before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
+        # after PR: no Lt in addmm when either mat1 or mat2 has nrows/ncols <= 1,
+        # since the input preparation can swap matrices based on output
+        # row-/col-majorness.
+        inp = torch.randn(2, 768, device=device_type)
         lin(inp).sum().backward()
         torch.get_device_module(device_type).empty_cache()
         base_mem_mb = self._get_peak_active_memory_mb()
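A standalone sketch of the flip described in the comment above (mat1 here is the linear input, so its row count now also decides Lt eligibility):

    import torch

    lin = torch.nn.Linear(768, 768, device="cuda")

    # mat1 has a single row: after #163955 the Lt path is skipped in forward
    out_gemm = lin(torch.randn(1, 768, device="cuda"))

    # both dims > 1: the bias-epilogue Lt path is eligible again
    out_lt = lin(torch.randn(2, 768, device="cuda"))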