Revert "Introduce CUDA-only _scaled_mm op (#106844)"
This reverts commit 9440a8cbec.
Reverted https://github.com/pytorch/pytorch/pull/106844 on behalf of https://github.com/izaitsevfb due to Breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/106844#issuecomment-1679858327))
parent 22f5889753
commit 1af324b560
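
For reference, the operator being reverted exposed FP8 matrix multiplication with optional per-tensor scales, an optional bias, and an extra amax output. A minimal usage sketch of that pre-revert API, reconstructed from the schema and tests removed in the diff below; it assumes a build that still contains the op and an SM 9.0 (H100-class) GPU:

    import torch

    device = "cuda"  # the op was CUDA-only and gated to compute capability >= 9.0
    x = torch.rand(16, 16, device=device).to(torch.float8_e4m3fn)
    y = torch.eye(16, device=device, dtype=torch.float8_e4m3fn).t()  # column-major second operand

    # The reverted schema returned a (result, amax) pair; bias/scales were keyword-only.
    out, amax = torch._scaled_mm(x, y, out_dtype=torch.float16)
    print(out.dtype, amax.item())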
@@ -5,7 +5,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
-#include <ATen/cuda/CUDADataType.h>
 #include <c10/cuda/CUDAFunctions.h>
 #include <c10/macros/Export.h>
 #include <c10/util/irange.h>
@@ -197,6 +196,7 @@ static size_t _getWorkspaceSize() {
   static size_t workspace_size = _parseChosenWorkspaceSize();
   return workspace_size;
 }

 } // anonymous namespace

 namespace at::cuda::blas {
@@ -876,115 +876,6 @@ template void gemm_and_bias(
     int64_t result_ld,
     GEMMAndBiasActivationEpilogue activation);

-void scaled_gemm(
-    char transa,
-    char transb,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    const void* mat1_ptr,
-    const void* mat1_scale_ptr,
-    int64_t mat1_ld,
-    ScalarType mat1_dtype,
-    const void* mat2_ptr,
-    const void* mat2_scale_ptr,
-    int64_t mat2_ld,
-    ScalarType mat2_dtype,
-    const void* bias_ptr,
-    ScalarType bias_dtype,
-    void* result_ptr,
-    const void *result_scale_ptr,
-    int64_t result_ld,
-    ScalarType result_dtype,
-    void* amax_ptr) {
-  const auto computeType = CUBLAS_COMPUTE_32F;
-  const auto scaleType = CUDA_R_32F;
-  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
-  CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
-  CuBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't');
-  CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(bias_dtype), m, n, result_ld);
-  CuBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
-  if (bias_ptr) {
-    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
-    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_RELU_BIAS);
-    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
-  }
-  size_t workspaceSize = _getWorkspaceSize();
-  auto workspace = at::empty(
-      {static_cast<int64_t>(workspaceSize)},
-      at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
-
-  CuBlasLtMatmulPreference preference;
-  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
-  cublasLtMatmulHeuristicResult_t heuristicResult = {};
-  int returnedResult = 0;
-  cublasLtHandle_t ltHandle =
-      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
-      ltHandle,
-      computeDesc.descriptor(),
-      Adesc.descriptor(),
-      Bdesc.descriptor(),
-      Cdesc.descriptor(),
-      Ddesc.descriptor(),
-      preference.descriptor(),
-      1,
-      &heuristicResult,
-      &returnedResult));
-  if (returnedResult == 0) {
-    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
-  }
-  float alpha_val = 1.0;
-  float beta_val = 0.0;
-  cublasStatus_t cublasStatus = cublasLtMatmul(
-      ltHandle,
-      computeDesc.descriptor(),
-      &alpha_val,
-      mat1_ptr,
-      Adesc.descriptor(),
-      mat2_ptr,
-      Bdesc.descriptor(),
-      &beta_val,
-      nullptr,
-      Cdesc.descriptor(),
-      result_ptr,
-      Ddesc.descriptor(),
-      &heuristicResult.algo,
-      workspace.data_ptr(),
-      workspaceSize,
-      at::cuda::getCurrentCUDAStream());
-  TORCH_CHECK(
-      cublasStatus == CUBLAS_STATUS_SUCCESS,
-      "CUDA error: ",
-      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
-      " when calling cublasLtMatmul with transpose_mat1 ",
-      transa,
-      " transpose_mat2 ",
-      transb,
-      " m ",
-      m,
-      " n ",
-      n,
-      " k ",
-      k,
-      " mat1_ld ",
-      mat1_ld,
-      " mat2_ld ",
-      mat2_ld,
-      " result_ld ",
-      result_ld,
-      " computeType ",
-      computeType,
-      " scaleType ",
-      scaleType);
-}
-
 void int8_gemm(
     bool transpose_mat1,
     bool transpose_mat2,
@@ -100,28 +100,6 @@ void int8_gemm(
     int64_t mat2_ld,
     int32_t* result_ptr,
     int64_t result_ld);
-
-void scaled_gemm(
-    char transa,
-    char transb,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    const void* mat1_ptr,
-    const void* mat1_scale_ptr,
-    int64_t mat1_ld,
-    ScalarType mat1_dtype,
-    const void* mat2_ptr,
-    const void* mat2_scale_ptr,
-    int64_t mat2_ld,
-    ScalarType mat2_dtype,
-    const void* bias,
-    ScalarType bias_dtype,
-    void* result_ptr,
-    const void* result_scale_ptr,
-    int64_t result_ld,
-    ScalarType result_dtype,
-    void* amax_ptr);
 #endif

 #define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
@@ -15,7 +15,6 @@
 #else
 #include <ATen/ops/_addmm_activation_native.h>
 #include <ATen/ops/_efficientzerotensor.h>
-#include <ATen/ops/_scaled_mm_native.h>
 #include <ATen/ops/addmm_native.h>
 #include <ATen/ops/addmv_native.h>
 #include <ATen/ops/baddbmm_native.h>
@@ -714,114 +713,4 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
   return _int_mm_out_cuda(self, mat2, result);
 }

-// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
-// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
-// If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed.
-// Known limitations:
-//  - Only works if mat1 is row-major and mat2 is column-major
-//  - Only works if matrices sizes are divisible by 32
-std::tuple<Tensor&, Tensor&>
-_scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
-          const c10::optional<at::Tensor>& bias,
-          c10::optional<c10::ScalarType> out_dtype,
-          const c10::optional<at::Tensor>& scale_a,
-          const c10::optional<at::Tensor>& scale_b,
-          const c10::optional<at::Tensor>& scale_result,
-          Tensor& out, Tensor& amax) {
-  // Check sizes
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-  TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat),
-      "scale_a must be float scalar");
-  TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat),
-      "scale_b must be a float scalar");
-  TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
-      "scale_result must be a float scalar");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-      " but got ", bias->numel());
-  TORCH_CHECK(mat1.sizes()[0] % 16 == 0 && mat1.sizes()[1] % 16 == 0, "mat1 shape (", mat1.sizes()[0], "x",
-      mat1.sizes()[1], " must be divisible by 16");
-  TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x",
-      mat2.sizes()[1], " must be divisible by 16");
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat1.scalar_type());
-  // Type restrictions imposed by CuBLASLt as of CUDA-12.1
-  TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
-      "Multiplication of two Float8_e5m2 matrices is not supported");
-  if (bias) {
-    TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
-        "Bias must be either Half or BFloat16, but got ", bias->scalar_type());
-    TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
-        bias->scalar_type() == ScalarType::BFloat16,
-        "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
-    TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
-        "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
-  }
-  {
-    auto bias_ = bias.value_or(Tensor());
-    auto scale_a_ = scale_a.value_or(Tensor());
-    auto scale_b_ = scale_b.value_or(Tensor());
-    auto scale_result_ = scale_result.value_or(Tensor());
-    TensorArg targs[]{{out, "out", 0}, {amax, "amax", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3},
-                      {bias_, "bias", 4}, {scale_a_, "scale_a", 5}, {scale_b_, "scale_b", 6},
-                      {scale_result_, "scale_result", 7}};
-    checkAllSameGPU(__func__, targs);
-  }
-
-  IntArrayRef mat1_sizes = mat1.sizes();
-  IntArrayRef mat2_sizes = mat2.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-  at::native::resize_output(amax, {});
-
-#if !defined(USE_ROCM) && !defined(_MSC_VER)
-  cublasCommonArgs args(mat1, mat2, out);
-  const auto out_dtype_ = args.result->scalar_type();
-  TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
-  at::cuda::blas::scaled_gemm(
-      args.transa,
-      args.transb,
-      args.m,
-      args.n,
-      args.k,
-      args.mata->data_ptr(),
-      scale_a ? scale_a->data_ptr() : nullptr,
-      args.lda,
-      args.mata->scalar_type(),
-      args.matb->data_ptr(),
-      scale_b ? scale_b->data_ptr() : nullptr,
-      args.ldb,
-      args.matb->scalar_type(),
-      bias ? bias->data_ptr() : nullptr,
-      bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
-      args.result->data_ptr(),
-      scale_result ? scale_result->data_ptr() : nullptr,
-      args.result_ld,
-      out_dtype_,
-      amax.data_ptr());
-#else
-  TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
-#endif
-
-  return {out, amax};
-}
-
-std::tuple<Tensor, Tensor>
-_scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
-          const c10::optional<at::Tensor>& bias,
-          c10::optional<c10::ScalarType> out_dtype,
-          const c10::optional<at::Tensor>& scale_a,
-          const c10::optional<at::Tensor>& scale_b,
-          const c10::optional<at::Tensor>& scale_result) {
-  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
-  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
-  Tensor amax = at::empty({0}, mat_a.options().dtype(ScalarType::Float));
-  return _scaled_mm_out_cuda(mat_a, mat_b, bias, out_dtype, scale_a, scale_b, scale_result, out, amax);
-}
-
 } // namespace at::native
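
The TORCH_CHECKs deleted above spell out the calling contract of the removed kernel: both operands are 2-D FP8 matrices with matching inner dimensions, every dimension divisible by 16, mat1 row-major and mat2 column-major, scales passed as float scalars, and no Float8_e5m2 x Float8_e5m2 pairing. A hedged sketch of inputs that would have satisfied those checks, assuming a pre-revert build on an H100-class GPU (the helper make_fp8_operands is illustrative only, not part of the codebase):

    import torch

    def make_fp8_operands(m=32, k=64, n=48, device="cuda"):
        # Shapes chosen so every dimension is divisible by 16, per the removed checks.
        a = torch.randn(m, k, device=device).to(torch.float8_e4m3fn)      # row-major
        b = torch.randn(n, k, device=device).to(torch.float8_e4m3fn).t()  # column-major via .t()
        return a, b

    a, b = make_fp8_operands()
    scale_a = torch.tensor(1.0, device="cuda")  # scales had to be float scalars
    scale_b = torch.tensor(1.0, device="cuda")
    out, amax = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                                 out_dtype=torch.bfloat16)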
@@ -6651,16 +6651,6 @@
   structured_delegate: _addmm_activation.out
   variants: function, method

-- func: _scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None) -> (Tensor, Tensor)
-  variants: function
-  dispatch:
-    CUDA: _scaled_mm_cuda
-
-- func: _scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!))
-  variants: function
-  dispatch:
-    CUDA: _scaled_mm_out_cuda
-
 # NOTE [ Sparse: autograd and API ]
 #
 #
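
These two schema entries were the only bindings for the operator, so removing them also removes torch._scaled_mm from the Python surface. Code that has to run on both sides of this revert can feature-test for the op; a small illustrative sketch (fp8_mm is a hypothetical wrapper, not a PyTorch API):

    import torch

    if hasattr(torch, "_scaled_mm"):
        def fp8_mm(x, y, out_dtype=torch.float16):
            out, _amax = torch._scaled_mm(x, y, out_dtype=out_dtype)  # pre-revert: returns (out, amax)
            return out
    else:
        def fp8_mm(x, y, out_dtype=torch.float16):
            # Post-revert fallback: plain matmul in a wider dtype.
            return (x.to(torch.float32) @ y.to(torch.float32)).to(out_dtype)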
@@ -448,8 +448,6 @@ aten::_scaled_dot_product_efficient_attention
 aten::_scaled_dot_product_efficient_attention_backward
 aten::_scaled_dot_product_flash_attention
 aten::_scaled_dot_product_flash_attention_backward
-aten::_scaled_mm
-aten::_scaled_mm.out
 aten::_segment_reduce_backward
 aten::_segment_reduce_backward.out
 aten::_slow_conv2d_backward.grad_input
@@ -2,7 +2,6 @@

 import unittest
 from functools import partial
-from typing import Optional

 import torch
 from torch.testing import make_tensor
@@ -176,62 +175,7 @@ class TestMatmulCuda(TestCase):
         self.assertEqual(out1_gpu, out2_gpu[0])


-@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
-@unittest.skipIf(not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0), "FP8 is only supported on H100+")
-class TestFP8MatmulCuda(TestCase):
-    def _test_tautological_mm(self, device: str = "cuda",
-                              x_dtype: torch.dtype = torch.float8_e4m3fn,
-                              y_dtype: torch.dtype = torch.float8_e4m3fn,
-                              out_dtype: Optional[torch.dtype] = None,
-                              size: int = 16) -> None:
-        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
-        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
-        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
-        (out_fp8, amax_fp8) = torch._scaled_mm(x_fp8, y_fp8, out_dtype=out_dtype)
-        if out_dtype is not None:
-            self.assertEqual(out_dtype, out_fp8.dtype)
-        if out_dtype not in [torch.float16, torch.bfloat16, torch.float]:
-            self.assertEqual(out_fp32.amax(), amax_fp8)
-        self.assertEqual(out_fp32, out_fp8.to(torch.float))
-
-    def test_float8_basics(self, device) -> None:
-        self._test_tautological_mm(device, torch.float8_e4m3fn, torch.float8_e4m3fn, size=16)
-        self._test_tautological_mm(device, torch.float8_e4m3fn, torch.float8_e5m2, size=32)
-        self._test_tautological_mm(device, torch.float8_e5m2, torch.float8_e4m3fn, size=48)
-        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
-        with self.assertRaises(RuntimeError):
-            self._test_tautological_mm(device, torch.float8_e5m2, torch.float8_e5m2)
-
-    def test_float8_out_dtype(self, device) -> None:
-        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
-        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
-        self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
-        with self.assertRaises(RuntimeError):
-            self._test_tautological_mm(device, out_dtype=torch.float8_e5m2)
-
-    def test_float8_scale(self, device) -> None:
-        size = (16, 16)
-        x = torch.full(size, .5, device=device, dtype=torch.float8_e4m3fn)
-        y = torch.full(size, .5, device=device, dtype=torch.float8_e5m2).t()
-        scale_a = torch.tensor(1.5, device=device)
-        scale_b = torch.tensor(0.66, device=device)
-        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
-        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
-        out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
-        self.assertEqual(out_fp8, out_fp8_s)
-
-    def test_float8_bias(self, device) -> None:
-        (k, l, m) = (16, 48, 32)
-        x = torch.rand((k, l), device=device).to(torch.float8_e4m3fn)
-        y = torch.full((m, l), .25, device=device, dtype=torch.float8_e4m3fn).t()
-        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
-        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
-        outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
-        self.assertEqual((amaxb_fp8 - amax_fp8).item(), 4.0)
-
-
 instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
-instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")

 if __name__ == '__main__':
     run_tests()
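
One expectation in the removed test_float8_bias is worth spelling out: the bias adds the same constant (4.0) to every output element, so the maximum, and hence the reported amax, shifts by exactly that constant, which is why the test asserts amaxb_fp8 - amax_fp8 == 4.0. A float32 re-derivation of that arithmetic, independent of the FP8 op and runnable on CPU:

    import torch

    k, l, m = 16, 48, 32
    x = torch.rand(k, l)          # same shapes as the removed test, emulated in float32
    y = torch.full((l, m), 0.25)
    bias = torch.full((m,), 4.0)

    out = x @ y
    out_b = out + bias            # bias broadcasts across rows

    # Adding a constant shifts the maximum by that constant.
    assert torch.isclose(out_b.amax() - out.amax(), torch.tensor(4.0))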