ROCm: Disable torch check for Multiplication of two Float8_e5m2 matrices (#148228)
ROCm supports multiplication of two Float8_e5m2 matrices, so the TORCH_CHECK that rejects this combination is disabled for ROCm builds.
Test command (on ROCm hardware supporting fp8): python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_float8_basics_cuda -v
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148228
Approved by: https://github.com/jeffdaily, https://github.com/petrex
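For context, the check being relaxed sits in the CUDA/HIP backend behind torch._scaled_mm. Below is a minimal sketch of the call this change allows on a ROCm build with fp8-capable hardware; torch._scaled_mm is a private API and its exact signature (tensor-valued scale_a/scale_b, out_dtype keyword) may differ between releases, so treat this as illustrative rather than canonical.

import torch

device = "cuda"  # ROCm devices are also exposed under the "cuda" device type in PyTorch

# Both operands in float8_e5m2; on non-ROCm builds this combination is rejected
# by the TORCH_CHECK that this commit wraps in #ifndef USE_ROCM.
a = torch.randn(32, 64, device=device).to(torch.float8_e5m2)
b = torch.randn(32, 64, device=device).to(torch.float8_e5m2).t()  # column-major layout, as _scaled_mm expects

scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # expected: torch.Size([32, 32]) torch.bfloat16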
This commit is contained in:
parent e6800bda7f
commit ed9c8a5d13
@@ -1149,9 +1149,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
   TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
   TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+#ifndef USE_ROCM
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
+#endif
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,