enable fp8 cast for inductor CPU (#117737)

Enable FP8 cast for this issue https://github.com/pytorch/pytorch/issues/117119.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117737
Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
haozhe.zhu 2024-01-22 07:49:58 +00:00 committed by PyTorch MergeBot
parent d8420c0b0c
commit d01ba4e94e
3 changed files with 26 additions and 1 deletions

View File

@ -33,7 +33,12 @@ from torch._inductor.utils import timed
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing._internal.common_utils import IS_MACOS, slowTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_MACOS,
parametrize,
slowTest,
)
from torch.utils._python_dispatch import TorchDispatchMode
try:
@ -83,6 +88,7 @@ class LstmModule(torch.nn.Module):
return x, h
@instantiate_parametrized_tests
class CPUReproTests(TestCase):
common = check_model
@ -2780,6 +2786,18 @@ class CPUReproTests(TestCase):
"Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True
).run(code)
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("shape", ("15,3,13", "4,2048,4096"))
def test_fp8_cast(self, dtype: torch.dtype, shape: str):
def fp8_cast(x):
y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
return y0, y1
shape = [int(dim) for dim in shape.split(",")]
x = torch.rand(*shape, device="cpu", dtype=dtype)
self.common(fp8_cast, (x,))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -66,6 +66,8 @@ DTYPE_TO_CPP = {
torch.bool: "bool",
torch.bfloat16: "bfloat16",
torch.complex64: "complex64",
torch.float8_e4m3fn: "float8_e4m3fn",
torch.float8_e5m2: "float8_e5m2",
}
DTYPE_TO_ATEN = {

View File

@ -11,6 +11,8 @@
#include <ATen/core/PhiloxRNGEngine.h>
#include <ATen/native/Math.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/generic_math.h>
@ -31,6 +33,9 @@
typedef at::Half half;
typedef at::BFloat16 bfloat16;
typedef at::Float8_e4m3fn float8_e4m3fn;
typedef at::Float8_e5m2 float8_e5m2;
template <typename T>
struct Welford {
T mean = T(0);