Revert "torch._scaled_mm with MXFP8 (#147548)"

This reverts commit 12b9674cb6.

Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - similar to previous, see below ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2684134336))
PyTorch MergeBot 2025-02-26 07:17:24 +00:00
parent 4216478250
commit a84db75e1b
7 changed files with 16 additions and 463 deletions


@@ -14,7 +14,6 @@
#include <c10/macros/Export.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>
#ifdef USE_ROCM
#include <hipblaslt/hipblaslt-ext.hpp>
@@ -1504,12 +1503,10 @@ void scaled_gemm(
const void* mat1_scale_ptr,
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
@@ -1537,8 +1534,10 @@ void scaled_gemm(
// rowwise isn't supported using cublaslt or older hipblaslt
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
#endif
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
{
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
}
if (result_scale_ptr != nullptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
@@ -1561,15 +1560,6 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if CUDA_VERSION >= 12080
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
#endif // CUDA_VERSION >= 12080
}
auto stream = c10::cuda::getCurrentCUDAStream();
size_t workspaceSize = 0;
auto workspace_ptr = _getWorkspace(workspaceSize);


@@ -130,12 +130,10 @@ void scaled_gemm(
const void* mat1_scale_ptr,
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,


@@ -10,7 +10,6 @@
#pragma once
#include <string>
#include <c10/core/ScalarType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/CUDABlas.h>
@@ -425,12 +424,10 @@ struct ScaledGemmParams : OpParams {
const void* a_scale_ptr{};
int64_t lda{};
ScalarType a_dtype{};
ScalarType a_scale_dtype{};
const void* b{};
const void* b_scale_ptr{};
int64_t ldb{};
ScalarType b_dtype{};
ScalarType b_scale_dtype{};
const void* bias_ptr{};
ScalarType bias_dtype{};
void* c{};


@@ -95,12 +95,10 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->a_scale_ptr,
params->lda,
params->a_dtype,
params->a_scale_dtype,
params->b,
params->b_scale_ptr,
params->ldb,
params->b_dtype,
params->b_scale_dtype,
params->bias_ptr,
params->bias_dtype,
params->c,


@@ -1,5 +1,4 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
@@ -96,33 +95,11 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
}
struct cublasCommonArgs {
cublasCommonArgs(
const Tensor& mat1,
const Tensor& mat2,
Tensor& c,
const std::optional<Tensor>& scale_a = c10::nullopt,
const std::optional<Tensor>& scale_b = c10::nullopt,
const std::optional<Tensor>& scale_result = c10::nullopt) {
cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false;
result = prepare_matrix_for_cublas(c, transpose_result);
mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
// Handle scale tensors if provided
if (scale_a && scale_b) {
// By default since we return in row-major we run the gemm
// as B.T @ A.T, check transpose_result to determine if we flip the scales
scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
}
if (scale_result) {
scale_result_ptr = scale_result->data_ptr();
scale_result_dtype = scale_result->scalar_type();
}
auto mat1_sizes = mat1.sizes();
auto mat2_sizes = mat2.sizes();
if (transpose_result) {
@@ -138,23 +115,13 @@ struct cublasCommonArgs {
lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
result_ld = result->stride(transpose_result ? 0 : 1);
transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
}
// Matrix members
char transa, transb;
int64_t m, n, k;
int64_t lda, ldb, result_ld;
c10::MaybeOwned<Tensor> mata, matb, result;
// Scale members
void* scale_mata_ptr = nullptr;
void* scale_matb_ptr = nullptr;
void* scale_result_ptr = nullptr;
std::optional<c10::ScalarType> scale_mata_dtype;
std::optional<c10::ScalarType> scale_matb_dtype;
std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace
@@ -936,10 +903,9 @@ static bool _scaled_mm_is_fnuz() {
namespace{
enum class ScalingType : std::uint8_t {
enum class ScalingType {
TensorWise,
RowWise,
BlockWise,
Error
};
/*
@@ -947,13 +913,10 @@ enum class ScalingType : std::uint8_t {
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensors are Float8_e8m0fnu:
* - Returns BlockWise (with additional size checks).
*
* - If scale_a.numel() == 1 && scale_b.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
* - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
* - Returns RowWise.
*
* - Otherwise:
@@ -966,40 +929,7 @@ ScalingType get_scaling_type(
const at::Tensor& scale_a,
const at::Tensor& scale_b,
int64_t dim_m,
int64_t dim_k,
int64_t dim_n) {
// Check for BlockWise scaling (FP8_E8M0 types)
if (scale_a.scalar_type() == scale_b.scalar_type() &&
scale_a.scalar_type() == at::kFloat8_e8m0fnu) {
constexpr int64_t BLOCK_SIZE_K = 32;
constexpr int64_t BLOCK_SIZE_MN = 128;
auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
auto num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
// Check expected sizes for block-wise scaling
auto expected_a_size =
BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
auto expected_b_size =
BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;
TORCH_CHECK(scale_a.numel() == expected_a_size,
"For BlockWise scaling: Expected scale_a size to be ",
expected_a_size, " but got ", scale_a.numel());
TORCH_CHECK(scale_b.numel() == expected_b_size,
"For BlockWise scaling: Expected scale_b size to be ",
expected_b_size, " but got ", scale_b.numel());
TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"For BlockWise scaling: Both scale_a and scale_b must be contiguous");
return ScalingType::BlockWise;
}
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
TORCH_CHECK(
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
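
For reference, a worked example of the BlockWise size check removed above, using the M=1024, K=512, N=2048 shapes that the removed error-message test further down also uses (an illustrative sketch, not part of the diff):

def ceil_div(a, b):
    return (a + b - 1) // b

M, K, N = 1024, 512, 2048
BLOCK_SIZE_K, BLOCK_SIZE_MN = 32, 128
num_k_blocks = ceil_div(K, BLOCK_SIZE_K)               # 16
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4    # 16
expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks  # 128 * 8 * 16 = 16384
expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks  # 128 * 16 * 16 = 32768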
@@ -1097,7 +1027,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
// Check what type of scaling we are doing based on inputs
ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
@@ -1190,7 +1120,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
#endif
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -1300,7 +1230,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
@@ -1308,19 +1238,17 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
scale_a.data_ptr(),
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
scale_b.data_ptr(),
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
scale_result ? scale_result->data_ptr() : nullptr,
args.result_ld,
out_dtype_,
use_fast_accum,


@@ -19,8 +19,7 @@ from torch.testing._internal.common_cuda import (
SM53OrLater,
SM89OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MX_GEMM
PLATFORM_SUPPORTS_FP8
)
from torch.testing._internal.common_device_type import (
dtypes,
@@ -251,7 +250,6 @@ class TestMatmulCuda(TestCase):
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
e4m3_type = torch.float8_e4m3fnuz
@@ -368,79 +366,6 @@ def to_fp8_saturated(
return x.to(fp8_dtype)
# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
def ceil_div(a, b):
return (a + b - 1) // b
def to_blocked(input_matrix) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
padded = input_matrix
# Ideally we would use torch.nn.pad but it doesn't support float8_e8m0fnu for now
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
padded[:rows, :cols] = input_matrix
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
return rearranged.flatten()
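
A minimal usage sketch of the to_blocked helper above (illustrative, not part of the diff; assumes a CUDA build of PyTorch with float8_e8m0fnu support): for a 128x128 GEMM with 32-element blocks, the per-row scales form a (128, 4) matrix, and the swizzled layout flattens to 32*ceil_div(128, 128) * 16*ceil_div(4, 4) = 512 elements.

import torch

scales = torch.full((128, 4), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)
swizzled = to_blocked(scales)      # helper defined above
assert swizzled.shape == (512,)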
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes the error between two tensors in dB.
For more details see:
https://en.wikipedia.org/wiki/Signal-to-noise_ratio
Args:
x: The original tensor.
y: The tensor to compare to the original tensor.
"""
Ps = torch.norm(x)
Pn = torch.norm(x - y)
return 20 * torch.log10(Ps / Pn)
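
A quick sanity check of compute_error (illustrative, not part of the diff): a uniform 1% relative error corresponds to 20 * log10(1 / 0.01) = 40 dB, comfortably above the 22 dB SQNR threshold used by test_blockwise_mxfp8_numerics below.

import torch

x = torch.full((16,), 100.0)
y = x * 1.01
print(compute_error(x, y))    # ~40 dB, using the helper defined above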
# largest power of 2 representable in `torch.float8_e4m3fn`
F8E4M3_LARGEST_POW2 = 8
# max value of `torch.float8_e4m3fn` (448)
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
# exponent bias of `torch.float8_e8m0fnu`
F8E8M0_EXP_BIAS = 127
def data_to_mx_scale(x, block_size):
# simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
# section 6.3, not all edge cases (such as NaN) are handled/tested
orig_shape = x.shape
x = x.reshape(-1, block_size)
max_abs = torch.amax(torch.abs(x), 1)
largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
return scale_e8m0_biased.reshape(orig_shape[0], -1)
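
Worked example of the data_to_mx_scale logic above (illustrative, not part of the diff): for a block whose largest magnitude is 1000, floor(log2(1000)) = 9, the unbiased e8m0 exponent is 9 - F8E4M3_LARGEST_POW2 = 1, the stored biased byte is 1 + 127 = 128, and the block is divided by a scale of 2.0 before the float8_e4m3fn cast.

import math

max_abs = 1000.0
exp_unbiased = math.floor(math.log2(max_abs)) - 8    # 9 - 8 = 1 (F8E4M3_LARGEST_POW2 = 8)
exp_biased = exp_unbiased + 127                      # 128, the byte stored as float8_e8m0fnu
scale = 2.0 ** exp_unbiased                          # 2.0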
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):
@@ -843,287 +768,6 @@ class TestFP8MatmulCuda(TestCase):
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
@parametrize("test_case_name", [
"a_eye_b_eye",
"a_ones_b_ones",
"a_ones_modified_b_ones",
"a_ones_b_ones_modified",
"a_scale_modified_b_ones",
"a_ones_b_scale_modified",
"data_random_scales_one",
"data_random_scales_from_data",
])
@parametrize("fast_accum", [False, True])
@parametrize("mkn", [
# Nice shapes
(128, 128, 128),
(256, 256, 256),
(128, 256, 512),
(256, 512, 128),
(512, 128, 256),
# Non block multiples
(65, 96, 112),
(197, 224, 272),
# K not multiple of 32
(197, 240, 272),
# Very unbalanced
(1023, 64, 48),
(31, 1024, 64),
(45, 96, 1024),
# Mixed large and small
(2, 1024, 128),
(127, 96, 1024),
(1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
def test_blockwise_mxfp8_numerics(self, test_case_name, fast_accum, mkn) -> None:
# inspiration: https://github.com/pytorch/ao/pull/1625
device = "cuda"
M, K, N = mkn
BLOCK_SIZE = 32
require_exact_match = True
def ceil_div(a, b):
return (a + b - 1) // b
if test_case_name == "a_eye_b_eye":
if not ((M == K) and (M == N)):
return unittest.skip("this test is only defined for M == K == N, skipping")
A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
A = A_ref.to(torch.float8_e4m3fn)
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
elif test_case_name == "a_ones_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
A = A_ref.to(torch.float8_e4m3fn)
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
elif test_case_name == "a_ones_modified_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
A = A_ref.to(torch.float8_e4m3fn)
B = B_ref.to(torch.float8_e4m3fn)
A_ref[1][0:BLOCK_SIZE] = 2
A[1][0:BLOCK_SIZE] = 2
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
elif test_case_name == "a_ones_b_ones_modified":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
A = A_ref.to(torch.float8_e4m3fn)
B = B_ref.to(torch.float8_e4m3fn)
B_ref[1][0:BLOCK_SIZE] = 2
B[1][0:BLOCK_SIZE] = 2
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
elif test_case_name == "a_scale_modified_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
A = A_ref.to(torch.float8_e4m3fn)
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
A_ref[1][0:BLOCK_SIZE] = 4
A[1][0:BLOCK_SIZE] = 2
A_scale[1][0] = 2
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
elif test_case_name == "a_ones_b_scale_modified":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
A = A_ref.to(torch.float8_e4m3fn)
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_ref[1][0:BLOCK_SIZE] = 4
B[1][0:BLOCK_SIZE] = 2
B_scale[1][0] = 2
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
elif test_case_name == "data_random_scales_one":
require_exact_match = False
# scales all-ones, element data random while being exactly representable in float8_e4m3fn
# generate integers in [0, 255] and interpret as float8_e4m3fn
A_ref = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
B_ref = torch.randint(0, 255, (N, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
# modification: don't allow NaN values
A_ref[torch.isnan(A_ref)] = 0
B_ref[torch.isnan(B_ref)] = 0
A = A_ref.to(torch.float8_e4m3fn)
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
elif test_case_name == "data_random_scales_from_data":
if not K % BLOCK_SIZE == 0:
return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
require_exact_match = False
# random data, scales from data
A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000
# Calculate scales based on the inputs
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
max_val = F8E4M3_MAX_VAL
min_val = -1 * max_val
A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
C_ref = A_ref @ B_ref.t()
C = torch._scaled_mm(
A,
B.t(),
A_scale,
B_scale,
out_dtype=torch.bfloat16,
use_fast_accum=fast_accum,
)
if require_exact_match:
torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
else:
sqnr = compute_error(C_ref, C)
assert sqnr.item() > 22.0
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@skipIfRocm()
def test_blockwise_mxfloat8_error_messages(self, device) -> None:
M, K, N = (1024, 512, 2048)
BLOCK_SIZE_K = 32
BLOCK_SIZE_MN = 128
fill_value = 0.5
x = torch.full((M, K), fill_value, device=device)
y = torch.full((N, K), fill_value, device=device)
x_fp8 = x.to(e4m3_type)
y_fp8 = y.to(e4m3_type).t()
def ceil_div(a, b):
return (a + b - 1) // b
num_k_blocks = ceil_div(K, BLOCK_SIZE_K)
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks
# Test wrong scale tensor size for scale_a with correct dtype
with self.assertRaisesRegex(
RuntimeError,
re.escape(
f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
f"but got {expected_a_size - 1}"
),
):
incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=torch.float8_e8m0fnu)
correct_size_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=incorrect_size_a,
scale_b=correct_size_b,
out_dtype=torch.bfloat16,
)
# Test wrong scale tensor size for scale_b with correct dtype
with self.assertRaisesRegex(
RuntimeError,
re.escape(
f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} "
f"but got {expected_b_size + 1}"
),
):
correct_size_a = torch.ones(expected_a_size, device=device, dtype=torch.float8_e8m0fnu)
incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=torch.float8_e8m0fnu)
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=correct_size_a,
scale_b=incorrect_size_b,
out_dtype=torch.bfloat16,
)
# Test non-contiguous scale tensors with correct dtype
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"For BlockWise scaling: Both scale_a and scale_b must be contiguous"
),
):
non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=torch.float8_e8m0fnu)[::2]
contiguous_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=non_contiguous_a,
scale_b=contiguous_b,
out_dtype=torch.bfloat16,
)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")


@@ -32,7 +32,6 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))
IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
and torch.cuda.get_device_capability()[1] > 0)
@@ -102,7 +101,6 @@ def evaluate_platform_supports_fp8():
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
if TEST_NUMBA:
try: