[cuBLASLt][FP8] cuBLASLt appears to support float8 rowwise-scaling on H100 (#161305)
Following #157905, I think the macro guarding

```
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
```

was never updated, which would cause the `float8` tests to fail. It also appears that the `Lt` path accepts two inputs with `e4m3` and `e5m2` dtypes simultaneously, so that check is removed here as well.

CC @lw

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161305
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
parent b2c7b9ad2d
commit c2a3024617
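For context, the path this change unblocks is the row-wise scaled FP8 GEMM exercised by the tests below: one scale per row of the first operand and one per column of the second, fed to `torch._scaled_mm`. A minimal usage sketch, not part of the PR; shapes and variable names are illustrative, it assumes an FP8-capable GPU (the row-wise tests below are gated to sm89 and newer), and it follows the scale layout used in the tests:

```python
import torch

# Illustrative sizes only; K must be a multiple of 16 for torch._scaled_mm.
M, K, N = 32, 64, 48
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

x_fp8 = x.to(torch.float8_e4m3fn)        # row-major first operand
w_fp8 = w.to(torch.float8_e4m3fn).t()    # column-major second operand, shape (K, N)

# One scale per row of A and one per column of B selects row-wise scaling,
# mirroring the (M, 1) / (1, N) scale tensors used in the tests below.
scale_a = torch.ones(M, 1, device="cuda", dtype=torch.float32)
scale_b = torch.ones(1, N, device="cuda", dtype=torch.float32)

out = torch._scaled_mm(
    x_fp8, w_fp8, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
)
print(out.shape)  # torch.Size([32, 48])
```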
@@ -1937,11 +1937,11 @@ void scaled_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
   cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
   cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
+#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   // hipblaslt supported row-wise before cublas, and did so their own way (via
   // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
   // the SCALE_MODEs). Here we check for this early custom mode.
   bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
-#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   if (use_rowwise) {
     matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
     matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
@@ -1956,8 +1956,12 @@ void scaled_gemm(
     }
 #endif
   }
-#else
-  // rowwise isn't supported using cublaslt or older hipblaslt
+#elif (CUDA_VERSION < 12090) && !defined(USE_ROCM)
+  // hipblaslt supported row-wise before cublas, and did so their own way (via
+  // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
+  // the SCALE_MODEs). Here we check for this early custom mode.
+  bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
+  // rowwise isn't supported using older cublaslt or older hipblaslt
   TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
 #endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
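The new `#elif` keys the hard assert to `CUDA_VERSION < 12090`: row-wise scaling is only rejected outright on non-ROCm builds with a CUDA toolkit older than 12.9, and newer cuBLASLt is left to handle the row-wise scales itself. A small illustration of the version encoding the guard relies on (standard CUDA macro convention, not code from this PR):

```python
# CUDA_VERSION is encoded as major * 1000 + minor * 10, so the new guard
# (CUDA_VERSION < 12090) keeps the assert only for toolkits older than 12.9.
def cuda_version_macro(major: int, minor: int) -> int:
    return major * 1000 + minor * 10

assert cuda_version_macro(12, 9) == 12090  # first release assumed to handle row-wise scales
assert cuda_version_macro(12, 8) < 12090   # older toolkit: TORCH_INTERNAL_ASSERT still fires
assert cuda_version_macro(13, 0) > 12090   # newer toolkit: the Lt path is used
```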
@@ -465,7 +465,10 @@ class TestFP8Lowering(TestCase):
         # autotuning for the compiled case, the results can be different because of
         # the way blocks of results are accumulated (float addition not associative), so
         # setting a small absolute tolerance in these tests
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+        if dtype == torch.bfloat16:
+            self.assertEqual(y_eager, y_compiled, rtol=5e-2, atol=0.07)
+        else:
+            self.assertEqual(y_eager, y_compiled, rtol=1e-2, atol=0.05)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(
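The split tolerances follow from the comment kept in the hunk above: eager and autotuned compiled kernels can accumulate blocks of partial products in different orders, and float addition is not associative, so bfloat16 results drift a bit more than float32 results. A tiny standalone illustration of that non-associativity (not from the test suite; the values are made up to force a rounding step):

```python
import torch

# bfloat16 keeps only 8 significant bits, so adding a small term to a large one
# can round away entirely; the summation order then changes the result.
vals = torch.tensor([256.0, 1.0, -256.0], dtype=torch.bfloat16)

left_to_right = (vals[0] + vals[1]) + vals[2]  # 256 + 1 rounds back to 256 -> 0.0
right_to_left = vals[0] + (vals[1] + vals[2])  # 1 - 256 = -255 is exact    -> 1.0
print(left_to_right.item(), right_to_left.item())
```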
@@ -611,7 +614,7 @@ class TestFP8Lowering(TestCase):
         )
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(
@@ -744,7 +747,7 @@ class TestFP8Lowering(TestCase):
         )
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
@@ -1315,18 +1315,26 @@ class TestFP8Matmul(TestCase):
                 out_dtype=torch.bfloat16,
             )

-        # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
-        ):
-            torch._scaled_mm(
+        def e5m2():
+            out = torch._scaled_mm(
                 x_fp8,
                 y_fp8.to(e5m2_type),
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N), device="cuda"),
                 out_dtype=torch.bfloat16,
             )
+            return out
+
+        if torch.cuda.get_device_capability() == (9, 0) and torch.version.cuda and torch.version.cuda >= "12.9":
+            out = e5m2()
+            self.assertEqual(out, torch.ones_like(out) * 128.)
+        else:
+            # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
+            ):
+                e5m2()

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
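For reference, the constant the new happy-path branch compares against follows from the usual de-quantization semantics of row-wise scaled FP8 GEMM: with all-ones inputs and unit scales, every output element equals the reduction length K. A hedged emulation sketch of those semantics, not the cuBLASLt kernel and not PyTorch source; the helper name and sizes are made up:

```python
import torch

def rowwise_scaled_mm_ref(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16):
    """Emulate row-wise scaled FP8 matmul: accumulate in float32, then apply
    the per-row scale of A and the per-column scale of B to the output."""
    # a_fp8: (M, K) float8, b_fp8: (K, N) float8
    # scale_a: (M, 1) float32, scale_b: (1, N) float32
    acc = a_fp8.float() @ b_fp8.float()
    return (acc * scale_a * scale_b).to(out_dtype)

# With all-ones inputs and unit scales every element of the result equals K,
# which is why the updated test above can compare against a constant tensor.
M, K, N = 16, 32, 16
a = torch.ones(M, K).to(torch.float8_e4m3fn)
b = torch.ones(K, N).to(torch.float8_e5m2)  # mixed e4m3/e5m2, as in the test
out = rowwise_scaled_mm_ref(a, b, torch.ones(M, 1), torch.ones(1, N))
print(out[0, 0].item())  # 32.0
```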