Revert "Introduce CUDA-only _scaled_mm op (#106844)"

This reverts commit 9440a8cbec.

Reverted https://github.com/pytorch/pytorch/pull/106844 on behalf of https://github.com/izaitsevfb because it breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/106844#issuecomment-1679858327))
PyTorch MergeBot 2023-08-16 02:05:29 +00:00
parent 22f5889753
commit 1af324b560
6 changed files with 1 addition and 311 deletions


@@ -5,7 +5,6 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDADataType.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>
@@ -197,6 +196,7 @@ static size_t _getWorkspaceSize() {
static size_t workspace_size = _parseChosenWorkspaceSize();
return workspace_size;
}
} // anonymous namespace
namespace at::cuda::blas {
@@ -876,115 +876,6 @@ template void gemm_and_bias(
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
void scaled_gemm(
char transa,
char transb,
int64_t m,
int64_t n,
int64_t k,
const void* mat1_ptr,
const void* mat1_scale_ptr,
int64_t mat1_ld,
ScalarType mat1_dtype,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
void* amax_ptr) {
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
CuBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't');
CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(bias_dtype), m, n, result_ld);
CuBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
if (bias_ptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_RELU_BIAS);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
size_t workspaceSize = _getWorkspaceSize();
auto workspace = at::empty(
{static_cast<int64_t>(workspaceSize)},
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
CuBlasLtMatmulPreference preference;
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Ddesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
if (returnedResult == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}
float alpha_val = 1.0;
float beta_val = 0.0;
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
nullptr,
Cdesc.descriptor(),
result_ptr,
Ddesc.descriptor(),
&heuristicResult.algo,
workspace.data_ptr(),
workspaceSize,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
" when calling cublasLtMatmul with transpose_mat1 ",
transa,
" transpose_mat2 ",
transb,
" m ",
m,
" n ",
n,
" k ",
k,
" mat1_ld ",
mat1_ld,
" mat2_ld ",
mat2_ld,
" result_ld ",
result_ld,
" computeType ",
computeType,
" scaleType ",
scaleType);
}
void int8_gemm(
bool transpose_mat1,
bool transpose_mat2,


@@ -100,28 +100,6 @@ void int8_gemm(
int64_t mat2_ld,
int32_t* result_ptr,
int64_t result_ld);
void scaled_gemm(
char transa,
char transb,
int64_t m,
int64_t n,
int64_t k,
const void* mat1_ptr,
const void* mat1_scale_ptr,
int64_t mat1_ld,
ScalarType mat1_dtype,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
const void* bias,
ScalarType bias_dtype,
void* result_ptr,
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
void* amax_ptr);
#endif
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \


@@ -15,7 +15,6 @@
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
@@ -714,114 +713,4 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
return _int_mm_out_cuda(self, mat2, result);
}
// Computes matrix multiply + bias while applying scaling to the input and output matrices, and computes amax.
// Scales are only applicable when matrices are of Float8 type and are assumed to be 1.0 by default.
// If the output matrix is a 16- or 32-bit type, scale_result is not applied and amax is not computed.
// Known limitations:
//  - Only works if mat1 is row-major and mat2 is column-major
//  - Only works if matrix sizes are divisible by 16
std::tuple<Tensor&, Tensor&>
_scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
const c10::optional<at::Tensor>& bias,
c10::optional<c10::ScalarType> out_dtype,
const c10::optional<at::Tensor>& scale_a,
const c10::optional<at::Tensor>& scale_b,
const c10::optional<at::Tensor>& scale_result,
Tensor& out, Tensor& amax) {
// Check sizes
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat),
"scale_a must be float scalar");
TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat),
"scale_b must be a float scalar");
TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK(mat1.sizes()[0] % 16 == 0 && mat1.sizes()[1] % 16 == 0, "mat1 shape (", mat1.sizes()[0], "x",
mat1.sizes()[1], ") must be divisible by 16");
TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x",
mat2.sizes()[1], ") must be divisible by 16");
// Check types
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be a Float8 matrix, got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be a Float8 matrix, got ", mat2.scalar_type());
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
if (bias) {
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
"Bias must be either Half or BFloat16, but got ", bias->scalar_type());
TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
bias->scalar_type() == ScalarType::BFloat16,
"Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
"Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
auto scale_a_ = scale_a.value_or(Tensor());
auto scale_b_ = scale_b.value_or(Tensor());
auto scale_result_ = scale_result.value_or(Tensor());
TensorArg targs[]{{out, "out", 0}, {amax, "amax", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3},
{bias_, "bias", 4}, {scale_a_, "scale_a", 5}, {scale_b_, "scale_b", 6},
{scale_result_, "scale_result", 7}};
checkAllSameGPU(__func__, targs);
}
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
at::native::resize_output(amax, {});
#if !defined(USE_ROCM) && !defined(_MSC_VER)
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
scale_a ? scale_a->data_ptr() : nullptr,
args.lda,
args.mata->scalar_type(),
args.matb->data_ptr(),
scale_b ? scale_b->data_ptr() : nullptr,
args.ldb,
args.matb->scalar_type(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
scale_result ? scale_result->data_ptr() : nullptr,
args.result_ld,
out_dtype_,
amax.data_ptr());
#else
TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
#endif
return {out, amax};
}
std::tuple<Tensor, Tensor>
_scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const c10::optional<at::Tensor>& bias,
c10::optional<c10::ScalarType> out_dtype,
const c10::optional<at::Tensor>& scale_a,
const c10::optional<at::Tensor>& scale_b,
const c10::optional<at::Tensor>& scale_result) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
Tensor amax = at::empty({0}, mat_a.options().dtype(ScalarType::Float));
return _scaled_mm_out_cuda(mat_a, mat_b, bias, out_dtype, scale_a, scale_b, scale_result, out, amax);
}
} // namespace at::native
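
The comment block on _scaled_mm_out_cuda above spells out the op's contract: FP8 inputs with mat1 row-major and mat2 column-major, optional float scalar scales (treated as 1.0 when omitted), an optional Half/BFloat16 bias, and an amax output. Below is a minimal Python usage sketch, assuming an FP8-capable GPU (compute capability 9.0+) and a build that still contains the reverted torch._scaled_mm; the shapes and variable names are illustrative only.

import torch

device = "cuda"
# mat1 must be row-major, mat2 column-major, and every dimension divisible by 16.
x = torch.rand(32, 64, device=device).to(torch.float8_e4m3fn)
y = torch.rand(48, 64, device=device).to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device=device)  # float scalar tensors; treated as 1.0 when omitted
scale_b = torch.tensor(1.0, device=device)

# With the default (FP8) output dtype the op also reports the result's amax;
# for a 16/32-bit out_dtype, scale_result and amax do not apply (see the comment above).
out, amax = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
print(out.shape, out.dtype, amax.item())

Passing out_dtype=torch.float16 or torch.bfloat16 instead yields a higher-precision output, at the cost of the amax computation.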


@@ -6651,16 +6651,6 @@
  structured_delegate: _addmm_activation.out
  variants: function, method

- func: _scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CUDA: _scaled_mm_cuda

- func: _scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!))
  variants: function
  dispatch:
    CUDA: _scaled_mm_out_cuda
# NOTE [ Sparse: autograd and API ]
#
#
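
The two entries removed above declare the operator schema: a functional variant returning (out, amax) and a CUDA-only out variant that writes into preallocated out and out_amax tensors. The sketch below shows how the out variant could be driven from Python, under the assumption that the usual multi-output out= tuple binding generated for .out overloads applies to this op as well (hypothetical once this revert lands).

import torch

m, k, n = 32, 64, 48  # all divisible by 16, as the CUDA kernel requires
a = torch.rand(m, k, device="cuda").to(torch.float8_e4m3fn)      # row-major FP8
b = torch.rand(n, k, device="cuda").to(torch.float8_e4m3fn).t()  # column-major FP8
out = torch.empty(m, n, device="cuda", dtype=torch.float8_e4m3fn)
amax = torch.empty((), device="cuda", dtype=torch.float32)       # amax must be a float scalar

# Assumed binding: the two mutable outputs (out, out_amax) are passed as an out= tuple.
torch._scaled_mm(a, b, out=(out, amax))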


@@ -448,8 +448,6 @@ aten::_scaled_dot_product_efficient_attention
aten::_scaled_dot_product_efficient_attention_backward
aten::_scaled_dot_product_flash_attention
aten::_scaled_dot_product_flash_attention_backward
aten::_scaled_mm
aten::_scaled_mm.out
aten::_segment_reduce_backward
aten::_segment_reduce_backward.out
aten::_slow_conv2d_backward.grad_input


@@ -2,7 +2,6 @@
import unittest
from functools import partial
from typing import Optional
import torch
from torch.testing import make_tensor
@@ -176,62 +175,7 @@ class TestMatmulCuda(TestCase):
        self.assertEqual(out1_gpu, out2_gpu[0])

@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0), "FP8 is only supported on H100+")
class TestFP8MatmulCuda(TestCase):
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = torch.float8_e4m3fn,
                              y_dtype: torch.dtype = torch.float8_e4m3fn,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        (out_fp8, amax_fp8) = torch._scaled_mm(x_fp8, y_fp8, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        if out_dtype not in [torch.float16, torch.bfloat16, torch.float]:
            self.assertEqual(out_fp32.amax(), amax_fp8)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    def test_float8_basics(self, device) -> None:
        self._test_tautological_mm(device, torch.float8_e4m3fn, torch.float8_e4m3fn, size=16)
        self._test_tautological_mm(device, torch.float8_e4m3fn, torch.float8_e5m2, size=32)
        self._test_tautological_mm(device, torch.float8_e5m2, torch.float8_e4m3fn, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, torch.float8_e5m2, torch.float8_e5m2)

    def test_float8_out_dtype(self, device) -> None:
        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=torch.float8_e5m2)

    def test_float8_scale(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=torch.float8_e4m3fn)
        y = torch.full(size, .5, device=device, dtype=torch.float8_e5m2).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    def test_float8_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(torch.float8_e4m3fn)
        y = torch.full((m, l), .25, device=device, dtype=torch.float8_e4m3fn).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
        outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
        self.assertEqual((amaxb_fp8 - amax_fp8).item(), 4.0)


instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    run_tests()