Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit 12b9674cb6.
Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - similar to previous, see below ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2684134336))
This commit is contained in:
parent 4216478250
commit a84db75e1b
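For context, the reverted PR exposed MXFP8 (block-scaled float8 matmul with e8m0 scales) through torch._scaled_mm. Below is a minimal, hedged sketch of the call pattern, modeled on test_blockwise_mxfp8_numerics in the test diff further down; it assumes a CUDA 12.8+ build on SM 10.0 hardware with the reverted change still applied, and the shapes are illustrative only.

```python
import torch

# Illustrative shapes; one e8m0 scale per 32-element block along K.
M, K, N = 128, 128, 128
BLOCK_SIZE = 32

A = torch.ones(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
B = torch.ones(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# All-ones scales pass through unchanged here; real scales must be rearranged
# into the cuBLASLt "blocked" layout (see the to_blocked() helper removed from
# the test file in this diff).
A_scale = torch.full((M, K // BLOCK_SIZE), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, K // BLOCK_SIZE), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)

C = torch._scaled_mm(A, B.t(), A_scale, B_scale, out_dtype=torch.bfloat16)
```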
@@ -14,7 +14,6 @@
#include <c10/macros/Export.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>

#ifdef USE_ROCM
#include <hipblaslt/hipblaslt-ext.hpp>

@@ -1504,12 +1503,10 @@ void scaled_gemm(
    const void* mat1_scale_ptr,
    int64_t mat1_ld,
    ScalarType mat1_dtype,
    ScalarType mat1_scale_dtype,
    const void* mat2_ptr,
    const void* mat2_scale_ptr,
    int64_t mat2_ld,
    ScalarType mat2_dtype,
    ScalarType mat2_scale_dtype,
    const void* bias_ptr,
    ScalarType bias_dtype,
    void* result_ptr,

@@ -1537,8 +1534,10 @@ void scaled_gemm(
  // rowwise isn't supported using cublaslt or older hipblaslt
  TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
#endif
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
  {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
  }
  if (result_scale_ptr != nullptr) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
  }

@@ -1561,15 +1560,6 @@ void scaled_gemm(
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
  }

  if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if CUDA_VERSION >= 12080
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
#else
    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
#endif // CUDA_VERSION >= 12080
  }

  auto stream = c10::cuda::getCurrentCUDAStream();
  size_t workspaceSize = 0;
  auto workspace_ptr = _getWorkspace(workspaceSize);
@@ -130,12 +130,10 @@ void scaled_gemm(
    const void* mat1_scale_ptr,
    int64_t mat1_ld,
    ScalarType mat1_dtype,
    ScalarType mat1_scale_dtype,
    const void* mat2_ptr,
    const void* mat2_scale_ptr,
    int64_t mat2_ld,
    ScalarType mat2_dtype,
    ScalarType mat2_scale_dtype,
    const void* bias_ptr,
    ScalarType bias_dtype,
    void* result_ptr,
@@ -10,7 +10,6 @@
#pragma once

#include <string>
#include <c10/core/ScalarType.h>

#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/CUDABlas.h>

@@ -425,12 +424,10 @@ struct ScaledGemmParams : OpParams {
  const void* a_scale_ptr{};
  int64_t lda{};
  ScalarType a_dtype{};
  ScalarType a_scale_dtype{};
  const void* b{};
  const void* b_scale_ptr{};
  int64_t ldb{};
  ScalarType b_dtype{};
  ScalarType b_scale_dtype{};
  const void* bias_ptr{};
  ScalarType bias_dtype{};
  void* c{};

@@ -95,12 +95,10 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
        params->a_scale_ptr,
        params->lda,
        params->a_dtype,
        params->a_scale_dtype,
        params->b,
        params->b_scale_ptr,
        params->ldb,
        params->b_dtype,
        params->b_scale_dtype,
        params->bias_ptr,
        params->bias_dtype,
        params->c,
@@ -1,5 +1,4 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>

@@ -96,33 +95,11 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
}

struct cublasCommonArgs {
  cublasCommonArgs(
      const Tensor& mat1,
      const Tensor& mat2,
      Tensor& c,
      const std::optional<Tensor>& scale_a = c10::nullopt,
      const std::optional<Tensor>& scale_b = c10::nullopt,
      const std::optional<Tensor>& scale_result = c10::nullopt) {
  cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
    bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false;
    result = prepare_matrix_for_cublas(c, transpose_result);
    mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
    matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);

    // Handle scale tensors if provided
    if (scale_a && scale_b) {
      // By default since we return in row-major we run the gemm
      // as B.T @ A.T, check transpose_result to determine if we flip the scales
      scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
      scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
      scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
      scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
    }

    if (scale_result) {
      scale_result_ptr = scale_result->data_ptr();
      scale_result_dtype = scale_result->scalar_type();
    }

    auto mat1_sizes = mat1.sizes();
    auto mat2_sizes = mat2.sizes();
    if (transpose_result) {

@@ -138,23 +115,13 @@ struct cublasCommonArgs {
    lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
    ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
    result_ld = result->stride(transpose_result ? 0 : 1);
    transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
    transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
    transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
    transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
  }

  // Matrix members
  char transa, transb;
  int64_t m, n, k;
  int64_t lda, ldb, result_ld;
  c10::MaybeOwned<Tensor> mata, matb, result;

  // Scale members
  void* scale_mata_ptr = nullptr;
  void* scale_matb_ptr = nullptr;
  void* scale_result_ptr = nullptr;
  std::optional<c10::ScalarType> scale_mata_dtype;
  std::optional<c10::ScalarType> scale_matb_dtype;
  std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace

@@ -936,10 +903,9 @@ static bool _scaled_mm_is_fnuz() {

namespace{

enum class ScalingType : std::uint8_t {
enum class ScalingType {
  TensorWise,
  RowWise,
  BlockWise,
  Error
};
/*

@@ -947,13 +913,10 @@ enum class ScalingType : std::uint8_t {
 * ---------------------------
 * Conditions and corresponding Scaling Types:
 *
 * - If scale tensors are Float8_e8m0fnu:
 *   - Returns BlockWise (with additional size checks).
 *
 * - If scale_a.numel() == 1 && scale_b.numel() == 1:
 *   - Returns TensorWise.
 *
 * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
 * - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
 *   - Returns RowWise.
 *
 * - Otherwise:
@@ -966,40 +929,7 @@ ScalingType get_scaling_type(
    const at::Tensor& scale_a,
    const at::Tensor& scale_b,
    int64_t dim_m,
    int64_t dim_k,
    int64_t dim_n) {
  // Check for BlockWise scaling (FP8_E8M0 types)
  if (scale_a.scalar_type() == scale_b.scalar_type() &&
      scale_a.scalar_type() == at::kFloat8_e8m0fnu) {
    constexpr int64_t BLOCK_SIZE_K = 32;
    constexpr int64_t BLOCK_SIZE_MN = 128;

    auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
    auto num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
    auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;

    // TODO: We might want to enforce some structure on the shapes of the scale
    // tensors

    // Check expected sizes for block-wise scaling
    auto expected_a_size =
        BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
    auto expected_b_size =
        BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;

    TORCH_CHECK(scale_a.numel() == expected_a_size,
        "For BlockWise scaling: Expected scale_a size to be ",
        expected_a_size, " but got ", scale_a.numel());
    TORCH_CHECK(scale_b.numel() == expected_b_size,
        "For BlockWise scaling: Expected scale_b size to be ",
        expected_b_size, " but got ", scale_b.numel());

    TORCH_CHECK(
        scale_a.is_contiguous() && scale_b.is_contiguous(),
        "For BlockWise scaling: Both scale_a and scale_b must be contiguous");

    return ScalingType::BlockWise;
  }
  // Both Per-Tensor and Row-wise scaling expect fp32 tensors
  TORCH_CHECK(
      scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
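For reference, a worked example (plain Python arithmetic, not a PyTorch API) of the expected blocked-scale sizes computed by the get_scaling_type hunk above, using the same M, K, N as test_blockwise_mxfloat8_error_messages further down in this diff:

```python
def ceil_div(a, b):
    return (a + b - 1) // b

M, K, N = 1024, 512, 2048
BLOCK_SIZE_K, BLOCK_SIZE_MN = 32, 128

num_k_blocks = ceil_div(K, BLOCK_SIZE_K)             # 16 blocks of 32 along K
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4  # padded to a multiple of 4 -> 16

# scale_a must hold 128 * ceil(M/128) * padded_num_k_blocks e8m0 values; scale_b likewise for N
expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks  # 128 * 8 * 16 = 16384
expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks  # 128 * 16 * 16 = 32768
```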
@@ -1097,7 +1027,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  // Check what type of scaling we are doing based on inputs
  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
  TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");

  TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),

@@ -1190,7 +1120,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
  }
#endif

  cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
  cublasCommonArgs args(mat1, mat2, out);
  const auto out_dtype_ = args.result->scalar_type();
  TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");

@@ -1300,7 +1230,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
  }
  else
#endif
  {
  {
    at::cuda::blas::scaled_gemm(
        args.transa,
        args.transb,

@@ -1308,19 +1238,17 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
        args.n,
        args.k,
        args.mata->data_ptr(),
        args.scale_mata_ptr,
        scale_a.data_ptr(),
        args.lda,
        args.mata->scalar_type(),
        args.scale_mata_dtype.value(),
        args.matb->data_ptr(),
        args.scale_matb_ptr,
        scale_b.data_ptr(),
        args.ldb,
        args.matb->scalar_type(),
        args.scale_matb_dtype.value(),
        bias ? bias->data_ptr(): nullptr,
        bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
        args.result->data_ptr(),
        args.scale_result_ptr,
        scale_result ? scale_result->data_ptr() : nullptr,
        args.result_ld,
        out_dtype_,
        use_fast_accum,
@@ -19,8 +19,7 @@ from torch.testing._internal.common_cuda import (
    SM53OrLater,
    SM89OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8,
    PLATFORM_SUPPORTS_MX_GEMM
    PLATFORM_SUPPORTS_FP8
)
from torch.testing._internal.common_device_type import (
    dtypes,

@@ -251,7 +250,6 @@ class TestMatmulCuda(TestCase):


f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"

if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
    e4m3_type = torch.float8_e4m3fnuz
@@ -368,79 +366,6 @@ def to_fp8_saturated(

    return x.to(fp8_dtype)

# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
def ceil_div(a, b):
    return (a + b - 1) // b

def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

    See:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    # Ideally we would use torch.nn.pad but it doesn't support float8_e8m0fnu for now
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
        padded[:rows, :cols] = input_matrix

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()

def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes the error between two tensors in dB.

    For more details see:
    https://en.wikipedia.org/wiki/Signal-to-noise_ratio

    Args:
        x: The original tensor.
        y: The tensor to compare to the original tensor.
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)

# largest power of 2 representable in `torch.float8_e4m3fn`
F8E4M3_LARGEST_POW2 = 8
# max value of `torch.float8_e4m3fn` (448)
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
# exponent bias of `torch.float8_e8m0fnu`
F8E8M0_EXP_BIAS = 127

def data_to_mx_scale(x, block_size):
    # simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    # section 6.3, not all edge cases (such as NaN) are handled/tested
    orig_shape = x.shape
    x = x.reshape(-1, block_size)
    max_abs = torch.amax(torch.abs(x), 1)
    largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
    scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
    scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
    scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
    scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
    return scale_e8m0_biased.reshape(orig_shape[0], -1)


@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):

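A quick shape check for the removed to_blocked() helper above (a sketch; it assumes a CUDA device with e8m0 support and reuses ceil_div/to_blocked exactly as defined in that hunk):

```python
import torch

M, K, BLOCK_SIZE = 65, 96, 32
scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)  # (65, 3)
blocked = to_blocked(scale)

# (65, 3) is padded to one (128, 4) tile, rearranged, then flattened:
# 32 * ceil_div(65, 128) rows times 16 * ceil_div(3, 4) cols = 32 * 16 = 512 elements.
assert blocked.numel() == 32 * ceil_div(M, 128) * 16 * ceil_div(ceil_div(K, BLOCK_SIZE), 4)
```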
@@ -843,287 +768,6 @@ class TestFP8MatmulCuda(TestCase):
        self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
    @parametrize("test_case_name", [
        "a_eye_b_eye",
        "a_ones_b_ones",
        "a_ones_modified_b_ones",
        "a_ones_b_ones_modified",
        "a_scale_modified_b_ones",
        "a_ones_b_scale_modified",
        "data_random_scales_one",
        "data_random_scales_from_data",
    ])
    @parametrize("fast_accum", [False, True])
    @parametrize("mkn", [
        # Nice shapes
        (128, 128, 128),
        (256, 256, 256),
        (128, 256, 512),
        (256, 512, 128),
        (512, 128, 256),

        # Non block multiples
        (65, 96, 112),
        (197, 224, 272),
        # K not multiple of 32
        (197, 240, 272),

        # Very unbalanced
        (1023, 64, 48),
        (31, 1024, 64),
        (45, 96, 1024),

        # Mixed large and small
        (2, 1024, 128),
        (127, 96, 1024),
        (1025, 128, 96)
    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
    def test_blockwise_mxfp8_numerics(self, test_case_name, fast_accum, mkn) -> None:
        # inspiration: https://github.com/pytorch/ao/pull/1625

        device = "cuda"
        M, K, N = mkn
        BLOCK_SIZE = 32
        require_exact_match = True

        def ceil_div(a, b):
            return (a + b - 1) // b

        if test_case_name == "a_eye_b_eye":
            if not ((M == K) and (M == N)):
                return unittest.skip("this test is only defined for M == K == N, skipping")
            A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
            B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)

            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_ones_b_ones":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)

            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_ones_modified_b_ones":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)

            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_ref[1][0:BLOCK_SIZE] = 2
            A[1][0:BLOCK_SIZE] = 2

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_ones_b_ones_modified":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)

            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            B_ref[1][0:BLOCK_SIZE] = 2
            B[1][0:BLOCK_SIZE] = 2

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_scale_modified_b_ones":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)

            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            A_ref[1][0:BLOCK_SIZE] = 4
            A[1][0:BLOCK_SIZE] = 2
            A_scale[1][0] = 2

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_ones_b_scale_modified":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)

            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            B_ref[1][0:BLOCK_SIZE] = 4
            B[1][0:BLOCK_SIZE] = 2
            B_scale[1][0] = 2

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "data_random_scales_one":
            require_exact_match = False
            # scales all-ones, element data random while being exactly representable in float8_e4m3fn

            # generate integers in [0, 255] and interpret as float8_e4m3fn
            A_ref = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
            B_ref = torch.randint(0, 255, (N, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
            # modification: don't allow NaN values
            A_ref[torch.isnan(A_ref)] = 0
            B_ref[torch.isnan(B_ref)] = 0

            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "data_random_scales_from_data":
            if not K % BLOCK_SIZE == 0:
                return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
            require_exact_match = False
            # random data, scales from data
            A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
            B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000

            # Calculate scales based on the inputs
            A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
            B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)

            max_val = F8E4M3_MAX_VAL
            min_val = -1 * max_val

            A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
            A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
            B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
            B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        C_ref = A_ref @ B_ref.t()

        C = torch._scaled_mm(
            A,
            B.t(),
            A_scale,
            B_scale,
            out_dtype=torch.bfloat16,
            use_fast_accum=fast_accum,
        )

        if require_exact_match:
            torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
        else:
            sqnr = compute_error(C_ref, C)
            assert sqnr.item() > 22.0

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_blockwise_mxfloat8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        BLOCK_SIZE_K = 32
        BLOCK_SIZE_MN = 128
        fill_value = 0.5

        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(e4m3_type)
        y_fp8 = y.to(e4m3_type).t()

        def ceil_div(a, b):
            return (a + b - 1) // b

        num_k_blocks = ceil_div(K, BLOCK_SIZE_K)
        padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
        expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
        expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks


        # Test wrong scale tensor size for scale_a with correct dtype
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
                f"but got {expected_a_size - 1}"
            ),
        ):
            incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=torch.float8_e8m0fnu)
            correct_size_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=incorrect_size_a,
                scale_b=correct_size_b,
                out_dtype=torch.bfloat16,
            )

        # Test wrong scale tensor size for scale_b with correct dtype
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} "
                f"but got {expected_b_size + 1}"
            ),
        ):
            correct_size_a = torch.ones(expected_a_size, device=device, dtype=torch.float8_e8m0fnu)
            incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=torch.float8_e8m0fnu)
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=correct_size_a,
                scale_b=incorrect_size_b,
                out_dtype=torch.bfloat16,
            )

        # Test non-contiguous scale tensors with correct dtype
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For BlockWise scaling: Both scale_a and scale_b must be contiguous"
            ),
        ):
            non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=torch.float8_e8m0fnu)[::2]
            contiguous_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=non_contiguous_a,
                scale_b=contiguous_b,
                out_dtype=torch.bfloat16,
            )


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@@ -32,7 +32,6 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))

IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
                  and torch.cuda.get_device_capability()[1] > 0)

@@ -102,7 +101,6 @@ def evaluate_platform_supports_fp8():

PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())

PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)

if TEST_NUMBA:
    try: