mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add AOT inductor support for _scaled_mm for CPU (#141961)
This PR is to add AOT inductor support for _scaled_mm for CPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141961 Approved by: https://github.com/malfet ghstack dependencies: #139975
This commit is contained in:
parent
cbc4cf3043
commit
3fabd10c40
|
|
@ -48,7 +48,7 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.testing._internal.custom_tensor import CustomTensorPlainOut
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU
|
||||
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
||||
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
|
||||
from torch.utils import _pytree as pytree
|
||||
|
|
@ -690,16 +690,11 @@ class AOTInductorTestsTemplate:
|
|||
example_inputs = (x, y)
|
||||
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
|
||||
"FP8 is only supported on H100+",
|
||||
)
|
||||
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
||||
@skipIfXpu
|
||||
def test_fp8(self):
|
||||
# cuda only
|
||||
if self.device != "cuda":
|
||||
return
|
||||
if self.device == "cuda" and not SM90OrLater:
|
||||
raise unittest.SkipTest("FP8 is only supported on H100+")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, dtype):
|
||||
|
|
@ -720,16 +715,18 @@ class AOTInductorTestsTemplate:
|
|||
|
||||
dtype = torch.float16
|
||||
|
||||
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
||||
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
||||
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
|
||||
a_scale = torch.Tensor([1.0]).to(device=self.device)
|
||||
b_scale = torch.Tensor([1.0]).to(device=self.device)
|
||||
input_bias = torch.rand(32, device=self.device, dtype=dtype)
|
||||
weight_shape = (32, 16)
|
||||
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
|
||||
weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).T
|
||||
a_inverse_scale = 1 / a_scale
|
||||
b_inverse_scale = 1 / b_scale
|
||||
|
||||
x_shape = (16, 16)
|
||||
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
|
||||
x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
dim0_x = Dim("dim0_x", min=1, max=2048)
|
||||
dynamic_shapes = ({0: dim0_x}, None, None, None, None)
|
||||
self.check_model(
|
||||
|
|
@ -738,16 +735,11 @@ class AOTInductorTestsTemplate:
|
|||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
|
||||
"FP8 is only supported on H100+",
|
||||
)
|
||||
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
||||
@skipIfXpu
|
||||
def test_fp8_view_of_param(self):
|
||||
# cuda only
|
||||
if self.device != GPU_TYPE:
|
||||
return
|
||||
if self.device == "cuda" and not SM90OrLater:
|
||||
raise unittest.SkipTest("FP8 is only supported on H100+")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, dtype, weight):
|
||||
|
|
@ -4270,7 +4262,7 @@ copy_tests(
|
|||
)
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
|
||||
@unittest.skipIf(sys.platform == "darwin" or not HAS_GPU, "No CUDA on MacOS")
|
||||
class AOTInductorTestABICompatibleGpu(TestCase):
|
||||
device = GPU_TYPE
|
||||
device_type = GPU_TYPE
|
||||
|
|
@ -4292,5 +4284,5 @@ if __name__ == "__main__":
|
|||
from torch._inductor.test_case import run_tests
|
||||
|
||||
# cpp_extension N/A in fbcode
|
||||
if HAS_GPU or sys.platform == "darwin":
|
||||
if HAS_GPU or sys.platform == "darwin" or HAS_CPU:
|
||||
run_tests(needs="filelock")
|
||||
|
|
|
|||
|
|
@ -158,6 +158,9 @@ def supported_dtype_of_cpp_wrapper(dtype: torch.dtype, device_type: str) -> bool
|
|||
supported_dtype.add(torch.float8_e5m2)
|
||||
supported_dtype.add(torch.float8_e4m3fnuz)
|
||||
supported_dtype.add(torch.float8_e5m2fnuz)
|
||||
if device_type == "cpu":
|
||||
supported_dtype.add(torch.float8_e4m3fn)
|
||||
supported_dtype.add(torch.float8_e5m2)
|
||||
|
||||
return dtype in supported_dtype
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@
|
|||
// The following files are implemented in a header-only way and are guarded by
|
||||
// test/cpp/aoti_abi_check
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
|
|
@ -164,6 +166,12 @@ aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16* ret_value);
|
|||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64(
|
||||
AtenTensorHandle tensor,
|
||||
c10::complex<float>* ret_value);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e4m3fn(
|
||||
AtenTensorHandle tensor,
|
||||
c10::Float8_e4m3fn* ret_value);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e5m2(
|
||||
AtenTensorHandle tensor,
|
||||
c10::Float8_e5m2* ret_value);
|
||||
|
||||
// Functions for wrapping a scalar value to a single-element tensor
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32(
|
||||
|
|
@ -701,6 +709,8 @@ int32_t aoti_torch_dtype() = delete;
|
|||
namespace c10 {
|
||||
struct BFloat16;
|
||||
struct Half;
|
||||
struct Float8_e4m3fn;
|
||||
struct Float8_e5m2;
|
||||
} // namespace c10
|
||||
|
||||
DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16)
|
||||
|
|
@ -714,6 +724,8 @@ DEFINE_DTYPE_SPECIALIZATION(int16_t, int16)
|
|||
DEFINE_DTYPE_SPECIALIZATION(int32_t, int32)
|
||||
DEFINE_DTYPE_SPECIALIZATION(int64_t, int64)
|
||||
DEFINE_DTYPE_SPECIALIZATION(bool, bool)
|
||||
DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e4m3fn, float8_e4m3fn)
|
||||
DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e5m2, float8_e5m2)
|
||||
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -200,6 +200,8 @@ AOTI_TORCH_ITEM_IMPL(int64, int64_t)
|
|||
AOTI_TORCH_ITEM_IMPL(bool, bool)
|
||||
AOTI_TORCH_ITEM_IMPL(bfloat16, c10::BFloat16)
|
||||
AOTI_TORCH_ITEM_IMPL(complex64, c10::complex<float>)
|
||||
AOTI_TORCH_ITEM_IMPL(float8_e4m3fn, c10::Float8_e4m3fn)
|
||||
AOTI_TORCH_ITEM_IMPL(float8_e5m2, c10::Float8_e5m2)
|
||||
#undef AOTI_TORCH_ITEM_IMPL
|
||||
|
||||
#define AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(dtype, ctype, ttype) \
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user