Enable FP8 cast for inductor CPU (#117737)
Enable the FP8 cast on the inductor CPU backend to address https://github.com/pytorch/pytorch/issues/117119.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117737
Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
parent d8420c0b0c
commit d01ba4e94e
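For context, a minimal sketch of what this change enables, assuming a PyTorch build where inductor's CPU backend and the c10 FP8 dtypes are available; the function name and shapes are illustrative and mirror the cast pattern exercised by the new test below, not part of this PR.

import torch

def fp8_round_trip(x):
    # Round-trip through both FP8 formats; with this change inductor's CPU
    # codegen handles these casts, which it previously did not support.
    y0 = x.to(dtype=torch.float8_e4m3fn).to(x.dtype)
    y1 = x.to(dtype=torch.float8_e5m2).to(x.dtype)
    return y0, y1

compiled = torch.compile(fp8_round_trip, backend="inductor")
x = torch.rand(15, 3, 13, device="cpu", dtype=torch.bfloat16)
y0, y1 = compiled(x)

Running the sketch with TORCH_LOGS=output_code set should show the generated C++ kernel using the float8_e4m3fn / float8_e5m2 typedefs added in the header hunk below.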
@@ -33,7 +33,12 @@ from torch._inductor.utils import timed
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn import functional as F
-from torch.testing._internal.common_utils import IS_MACOS, slowTest
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    IS_MACOS,
+    parametrize,
+    slowTest,
+)
 from torch.utils._python_dispatch import TorchDispatchMode

 try:
@@ -83,6 +88,7 @@ class LstmModule(torch.nn.Module):
         return x, h


+@instantiate_parametrized_tests
 class CPUReproTests(TestCase):
     common = check_model

@@ -2780,6 +2786,18 @@ class CPUReproTests(TestCase):
             "Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True
         ).run(code)

+    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
+    @parametrize("shape", ("15,3,13", "4,2048,4096"))
+    def test_fp8_cast(self, dtype: torch.dtype, shape: str):
+        def fp8_cast(x):
+            y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
+            y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
+            return y0, y1
+
+        shape = [int(dim) for dim in shape.split(",")]
+        x = torch.rand(*shape, device="cpu", dtype=dtype)
+        self.common(fp8_cast, (x,))
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -66,6 +66,8 @@ DTYPE_TO_CPP = {
     torch.bool: "bool",
     torch.bfloat16: "bfloat16",
     torch.complex64: "complex64",
+    torch.float8_e4m3fn: "float8_e4m3fn",
+    torch.float8_e5m2: "float8_e5m2",
 }

 DTYPE_TO_ATEN = {
@@ -11,6 +11,8 @@
 #include <ATen/core/PhiloxRNGEngine.h>
 #include <ATen/native/Math.h>

+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/BFloat16-math.h>
 #include <c10/util/generic_math.h>
@@ -31,6 +33,9 @@
 typedef at::Half half;
 typedef at::BFloat16 bfloat16;

+typedef at::Float8_e4m3fn float8_e4m3fn;
+typedef at::Float8_e5m2 float8_e5m2;
+
 template <typename T>
 struct Welford {
   T mean = T(0);