Add AOT inductor support for _scaled_mm for CPU (#141961)

This PR adds AOT Inductor support for _scaled_mm on CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141961
Approved by: https://github.com/malfet
ghstack dependencies: #139975
Author: Jiang, Yanbing
Date: 2024-12-27 06:48:27 +00:00
Committed by: PyTorch MergeBot
Commit: 3fabd10c40 (parent cbc4cf3043)
4 changed files with 31 additions and 22 deletions
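As a rough illustration of the workflow this enables, the sketch below exports a small module that calls torch._scaled_mm and compiles it ahead of time with AOT Inductor on CPU. The module name, shapes, and scales are placeholders, not code from this PR, and the exact signatures of aoti_compile_and_package / aoti_load_package can vary across PyTorch releases.

```python
import torch
import torch._inductor


class FP8Matmul(torch.nn.Module):
    # Hypothetical example module, not part of this PR.
    def forward(self, x, weight, scale_a, scale_b, bias):
        # x: (M, K) float8_e4m3fn; weight: (K, N) float8_e4m3fn, column-major
        return torch._scaled_mm(
            x,
            weight,
            scale_a=scale_a,
            scale_b=scale_b,
            bias=bias,
            out_dtype=torch.bfloat16,
        )


x = torch.rand(16, 16).to(torch.float8_e4m3fn)
weight = torch.rand(32, 16).to(torch.float8_e4m3fn).T  # .T yields a column-major mat2
scale_a = torch.tensor([1.0])
scale_b = torch.tensor([1.0])
bias = torch.rand(32, dtype=torch.bfloat16)

ep = torch.export.export(FP8Matmul(), (x, weight, scale_a, scale_b, bias))
package_path = torch._inductor.aoti_compile_and_package(ep)  # produces a .pt2 archive
runner = torch._inductor.aoti_load_package(package_path)
out = runner(x, weight, scale_a, scale_b, bias)
```

With this change, the same export-and-package flow that previously required an SM90+ GPU for FP8 can also produce a CPU artifact.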


@@ -48,7 +48,7 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM,
 )
 from torch.testing._internal.custom_tensor import CustomTensorPlainOut
-from torch.testing._internal.inductor_utils import GPU_TYPE
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU
 from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
 from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
 from torch.utils import _pytree as pytree
@@ -690,16 +690,11 @@ class AOTInductorTestsTemplate:
         example_inputs = (x, y)
         self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
 
-    @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+",
-    )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
     def test_fp8(self):
-        # cuda only
-        if self.device != "cuda":
-            return
+        if self.device == "cuda" and not SM90OrLater:
+            raise unittest.SkipTest("FP8 is only supported on H100+")
 
         class Model(torch.nn.Module):
             def __init__(self, dtype):
@@ -720,16 +715,18 @@ class AOTInductorTestsTemplate:
         dtype = torch.float16
-        a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
-        b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
-        input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
+        a_scale = torch.Tensor([1.0]).to(device=self.device)
+        b_scale = torch.Tensor([1.0]).to(device=self.device)
+        input_bias = torch.rand(32, device=self.device, dtype=dtype)
         weight_shape = (32, 16)
-        weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
+        weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).T
         a_inverse_scale = 1 / a_scale
         b_inverse_scale = 1 / b_scale
         x_shape = (16, 16)
-        x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
+        x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(
+            torch.float8_e4m3fn
+        )
         dim0_x = Dim("dim0_x", min=1, max=2048)
         dynamic_shapes = ({0: dim0_x}, None, None, None, None)
         self.check_model(
@@ -738,16 +735,11 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
-    @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+",
-    )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
     def test_fp8_view_of_param(self):
-        # cuda only
-        if self.device != GPU_TYPE:
-            return
+        if self.device == "cuda" and not SM90OrLater:
+            raise unittest.SkipTest("FP8 is only supported on H100+")
 
         class Model(torch.nn.Module):
             def __init__(self, dtype, weight):
@@ -4270,7 +4262,7 @@ copy_tests(
 )
 
 
-@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
+@unittest.skipIf(sys.platform == "darwin" or not HAS_GPU, "No CUDA on MacOS")
 class AOTInductorTestABICompatibleGpu(TestCase):
     device = GPU_TYPE
     device_type = GPU_TYPE
@@ -4292,5 +4284,5 @@ if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
 
     # cpp_extension N/A in fbcode
-    if HAS_GPU or sys.platform == "darwin":
+    if HAS_GPU or sys.platform == "darwin" or HAS_CPU:
         run_tests(needs="filelock")
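The test changes above replace an import-time hardware check (the module-level @unittest.skipIf probing torch.cuda) plus a silent "cuda only" early return with a runtime skip, so the shared template body actually runs when copy_tests instantiates it for the CPU device while CUDA still requires SM90+. A minimal sketch of that pattern, with an illustrative test class that is not part of the PR:

```python
import unittest

import torch
from torch.testing._internal.common_cuda import SM90OrLater


class ExampleFP8Test(unittest.TestCase):
    # copy_tests-style per-device attribute; "cpu" here is illustrative.
    device = "cpu"

    def test_fp8_like(self):
        # Gate on hardware capability at run time rather than via a decorator,
        # mirroring the diff: CPU runs unconditionally, CUDA still needs SM90+.
        if self.device == "cuda" and not SM90OrLater:
            raise unittest.SkipTest("FP8 is only supported on H100+")
        x = torch.rand(4, 4).to(torch.float8_e4m3fn)
        self.assertEqual(x.dtype, torch.float8_e4m3fn)
```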


@@ -158,6 +158,9 @@ def supported_dtype_of_cpp_wrapper(dtype: torch.dtype, device_type: str) -> bool
         supported_dtype.add(torch.float8_e5m2)
         supported_dtype.add(torch.float8_e4m3fnuz)
         supported_dtype.add(torch.float8_e5m2fnuz)
+    if device_type == "cpu":
+        supported_dtype.add(torch.float8_e4m3fn)
+        supported_dtype.add(torch.float8_e5m2)
 
     return dtype in supported_dtype
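The hunk above extends the dtype allow-list that gates cpp-wrapper / AOT Inductor codegen. A quick, illustrative check of the new behavior, assuming the function lives in torch/_inductor/utils.py as it does in current sources:

```python
import torch
from torch._inductor.utils import supported_dtype_of_cpp_wrapper

# After this change, fp8 dtypes pass the cpp-wrapper check for the CPU device as well.
assert supported_dtype_of_cpp_wrapper(torch.float8_e4m3fn, "cpu")
assert supported_dtype_of_cpp_wrapper(torch.float8_e5m2, "cpu")
```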


@@ -54,6 +54,8 @@
 // The following files are implemented in a header-only way and are guarded by
 // test/cpp/aoti_abi_check
 #include <c10/util/BFloat16.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/Half.h>
 #include <c10/util/complex.h>
@@ -164,6 +166,12 @@ aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16* ret_value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64(
     AtenTensorHandle tensor,
     c10::complex<float>* ret_value);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e4m3fn(
+    AtenTensorHandle tensor,
+    c10::Float8_e4m3fn* ret_value);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e5m2(
+    AtenTensorHandle tensor,
+    c10::Float8_e5m2* ret_value);
 
 // Functions for wrapping a scalar value to a single-element tensor
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32(
@@ -701,6 +709,8 @@ int32_t aoti_torch_dtype() = delete;
 namespace c10 {
 struct BFloat16;
 struct Half;
+struct Float8_e4m3fn;
+struct Float8_e5m2;
 } // namespace c10
 
 DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16)
@@ -714,6 +724,8 @@ DEFINE_DTYPE_SPECIALIZATION(int16_t, int16)
 DEFINE_DTYPE_SPECIALIZATION(int32_t, int32)
 DEFINE_DTYPE_SPECIALIZATION(int64_t, int64)
 DEFINE_DTYPE_SPECIALIZATION(bool, bool)
+DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e4m3fn, float8_e4m3fn)
+DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e5m2, float8_e5m2)
 
 #endif


@@ -200,6 +200,8 @@ AOTI_TORCH_ITEM_IMPL(int64, int64_t)
 AOTI_TORCH_ITEM_IMPL(bool, bool)
 AOTI_TORCH_ITEM_IMPL(bfloat16, c10::BFloat16)
 AOTI_TORCH_ITEM_IMPL(complex64, c10::complex<float>)
+AOTI_TORCH_ITEM_IMPL(float8_e4m3fn, c10::Float8_e4m3fn)
+AOTI_TORCH_ITEM_IMPL(float8_e5m2, c10::Float8_e5m2)
 #undef AOTI_TORCH_ITEM_IMPL
 
 #define AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(dtype, ctype, ttype) \