Better 1x128, 128x128 error handling on non-Hopper (#166639)
Summary: Blockwise 1x128 and 128x128 scaling is only available with CUDA >= 12.9 and only on Hopper GPUs. Attempting to run it on B200 gives a hard-to-debug `CUBLAS_STATUS_NOT_SUPPORTED`. Add a more helpful `NotImplementedError` to catch this case. Also explicitly disable the relevant methods on ROCm builds, based on the lack of support per the [hipBLASLt docs](https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/reference/datatypes.html#_CPPv4N28hipblasLtMatmulMatrixScale_t40HIPBLASLT_MATMUL_MATRIX_SCALE_VEC128_32FE).

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166639
Approved by: https://github.com/drisspg
parent f911d64750
commit 99b05d1b78
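For illustration, a minimal sketch of the user-visible effect described in the summary, assuming a non-Hopper GPU (e.g. B200) with CUDA >= 12.9 and assuming torch._scaled_mm picks up the 1x128 blockwise recipe from fp32 scales shaped (M, K // 128) and (K // 128, N); the test added below drives the same checks through its own scaled_mm_wrap helper:

import torch

M, N, K = 256, 256, 512
x_fp8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
y_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major (K, N)

# Placeholder 1x128 blockwise scales along K; scale_a is made outer-dim-major,
# mirroring the layout the test below sets up.
scale_a = torch.ones(M, K // 128, device="cuda", dtype=torch.float32).t().contiguous().t()
scale_b = torch.ones(K // 128, N, device="cuda", dtype=torch.float32)

try:
    out = torch._scaled_mm(x_fp8, y_fp8, scale_a=scale_a, scale_b=scale_b,
                           out_dtype=torch.bfloat16)
except NotImplementedError as err:
    # Previously this configuration failed deep inside cuBLASLt with
    # CUBLAS_STATUS_NOT_SUPPORTED; it now fails up front with a clear message.
    print(err)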
@@ -794,6 +794,24 @@ void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const Sc
  }
}

void
_check_deepseek_support() {
#ifndef USE_ROCM
  auto dprops = at::cuda::getCurrentDeviceProperties();
  if (dprops->major != 9) {
    // Only on Hopper GPUs
    TORCH_CHECK_NOT_IMPLEMENTED(
        dprops->major == 9,
        "DeepSeek style (1x128, 128x128) scaling only supported in CUDA for SM90")
  }
  // Only in cublasLt >= 12.9
  TORCH_CHECK_NOT_IMPLEMENTED(
      CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
      "DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
  );
#endif
}

Tensor&
_scaled_block1x128_block1x128(
    const Tensor& mat_a, const Tensor& mat_b,
@@ -802,8 +820,12 @@ _scaled_block1x128_block1x128(
    const c10::ScalarType out_dtype,
    const bool use_fast_accum,
    Tensor& out) {
#ifndef USE_ROCM
  // Restrictions:
  // A, B are FP8, scales are fp32, shape K//128
  // CUDA: Only Hopper GPUs
  _check_deepseek_support();

  TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
      mat_a.scalar_type(), mat_b.scalar_type());
  TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
@@ -821,6 +843,12 @@ _scaled_block1x128_block1x128(
  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "1x128 and 128x128 scaling not available with ROCm"
  );
#endif
}

Tensor&
@@ -831,10 +859,12 @@ _scaled_block128x128_block1x128(
    const c10::ScalarType out_dtype,
    const bool use_fast_accum,
    Tensor& out) {
#ifndef USE_ROCM
  // Restrictions:
  // A, B are FP8, scales are fp32, shape K//128
  std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
  std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
  // CUDA: Only Hopper GPUs
  _check_deepseek_support();

  TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
      mat_a.scalar_type(), mat_b.scalar_type());
  TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
@@ -852,6 +882,12 @@ _scaled_block128x128_block1x128(
  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "1x128 and 128x128 scaling not available with ROCm"
  );
#endif
}

Tensor&
@@ -862,8 +898,12 @@ _scaled_block1x128_block128x128(
    const c10::ScalarType out_dtype,
    const bool use_fast_accum,
    Tensor& out) {
#ifndef USE_ROCM
  // Restrictions:
  // A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
  // CUDA: Only Hopper GPUs
  _check_deepseek_support();

  TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
      mat_a.scalar_type(), mat_b.scalar_type());
  TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
@@ -881,6 +921,12 @@ _scaled_block1x128_block128x128(
  _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);

  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "1x128 and 128x128 scaling not available with ROCm"
  );
#endif
}

Tensor&
@@ -1349,6 +1349,56 @@ class TestFP8Matmul(TestCase):
        # Verify that emulated F8 mm doesn't error
        mm_float8_emulated_block(x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype)

    @skipIfRocm
    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(IS_SM90, "cuBLAS blockwise scaling works on sm90")
    @unittest.skipIf(
        _get_torch_cuda_version() < (12, 9),
        "cuBLAS blockwise scaling added in CUDA 12.9",
    )
    @parametrize("output_dtype", [torch.bfloat16, ])
    @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
    @parametrize("M,N,K", [(256, 256, 256), (256, 256, 512)])
    def test_scaled_mm_deepseek_error_messages(
        self, output_dtype, lhs_block, rhs_block, M, N, K
    ):
        torch.manual_seed(42)

        x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
        y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)

        x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
        y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)

        # 1x128 blocks need scales to be outer-dim-major
        if lhs_block == 1:
            x_scales = x_scales.t().contiguous().t()
            lhs_recipe = ScalingType.BlockWise1x128
        else:
            lhs_recipe = ScalingType.BlockWise128x128

        if rhs_block == 1:
            y_scales = y_scales.t().contiguous().t()
            rhs_recipe = ScalingType.BlockWise1x128
        else:
            rhs_recipe = ScalingType.BlockWise128x128

        # Verify that the real F8 mm raises a descriptive error on non-SM90
        with self.assertRaisesRegex(
            NotImplementedError,
            ".*DeepSeek.*scaling.*only supported in CUDA for SM90.*"
        ):
            scaled_mm_wrap(
                x_fp8,
                y_fp8.t(),
                scale_a=x_scales,
                scale_recipe_a=lhs_recipe,
                scale_b=y_scales.t(),
                scale_recipe_b=rhs_recipe,
                out_dtype=output_dtype,
            )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("which_dim_zero", [0, 1, 2])
    @parametrize("use_torch_compile", [False, True])
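A minimal sketch of running only the new test in isolation, assuming TestFP8Matmul lives in test/test_matmul_cuda.py (not stated in this diff) and pytest is installed:

import subprocess
import sys

# Hypothetical invocation; the test file path is an assumption based on where
# TestFP8Matmul is defined.
subprocess.run(
    [
        sys.executable, "-m", "pytest", "-v",
        "test/test_matmul_cuda.py",
        "-k", "test_scaled_mm_deepseek_error_messages",
    ],
    check=True,  # raise CalledProcessError if the test run fails
)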