enable fp8 cast for inductor CPU (#117737)
Enable FP8 cast on the inductor CPU backend, addressing https://github.com/pytorch/pytorch/issues/117119.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117737
Approved by: https://github.com/jgong5, https://github.com/jansel
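For orientation, here is a minimal sketch (not part of the commit) of the pattern this change enables: a CPU-compiled graph that round-trips through the FP8 dtypes should now lower through the C++ backend rather than failing. It assumes a PyTorch build that includes this change.

import torch

def fp8_roundtrip(x):
    # Cast to both FP8 formats and back, mirroring the pattern the new test exercises.
    y0 = x.to(torch.float8_e4m3fn).to(x.dtype)
    y1 = x.to(torch.float8_e5m2).to(x.dtype)
    return y0, y1

compiled = torch.compile(fp8_roundtrip)
x = torch.rand(15, 3, 13, device="cpu", dtype=torch.bfloat16)
# FP8 round-trips are lossy, so compare compiled output against eager rather than the input.
for a, b in zip(fp8_roundtrip(x), compiled(x)):
    torch.testing.assert_close(a, b)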
This commit is contained in:
parent
d8420c0b0c
commit
d01ba4e94e
@@ -33,7 +33,12 @@ from torch._inductor.utils import timed
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn import functional as F
-from torch.testing._internal.common_utils import IS_MACOS, slowTest
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    IS_MACOS,
+    parametrize,
+    slowTest,
+)
 from torch.utils._python_dispatch import TorchDispatchMode

 try:
@@ -83,6 +88,7 @@ class LstmModule(torch.nn.Module):
         return x, h


+@instantiate_parametrized_tests
 class CPUReproTests(TestCase):
     common = check_model

@@ -2780,6 +2786,18 @@ class CPUReproTests(TestCase):
                 "Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True
             ).run(code)

+    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
+    @parametrize("shape", ("15,3,13", "4,2048,4096"))
+    def test_fp8_cast(self, dtype: torch.dtype, shape: str):
+        def fp8_cast(x):
+            y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
+            y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
+            return y0, y1
+
+        shape = [int(dim) for dim in shape.split(",")]
+        x = torch.rand(*shape, device="cpu", dtype=dtype)
+        self.common(fp8_cast, (x,))
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -66,6 +66,8 @@ DTYPE_TO_CPP = {
     torch.bool: "bool",
     torch.bfloat16: "bfloat16",
     torch.complex64: "complex64",
+    torch.float8_e4m3fn: "float8_e4m3fn",
+    torch.float8_e5m2: "float8_e5m2",
 }

 DTYPE_TO_ATEN = {
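DTYPE_TO_CPP maps each torch dtype to the C++ type name the CPU backend spells out in generated kernels; the two new entries are what allow codegen to name the FP8 types at all. A rough, hypothetical illustration of how such a mapping is consulted when rendering a cast (the helper below is invented for illustration, not inductor's actual API):

import torch

# Illustration-only mirror of the mapping extended by this commit.
DTYPE_TO_CPP = {
    torch.float: "float",
    torch.bfloat16: "bfloat16",
    torch.float8_e4m3fn: "float8_e4m3fn",
    torch.float8_e5m2: "float8_e5m2",
}

def emit_cast(expr: str, dst: torch.dtype) -> str:
    # Before this commit a float8 dtype had no entry here,
    # so a cast to it could not be rendered as C++.
    return f"static_cast<{DTYPE_TO_CPP[dst]}>({expr})"

print(emit_cast("tmp0", torch.float8_e4m3fn))  # -> static_cast<float8_e4m3fn>(tmp0)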
@@ -11,6 +11,8 @@
 #include <ATen/core/PhiloxRNGEngine.h>
 #include <ATen/native/Math.h>

+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/BFloat16-math.h>
 #include <c10/util/generic_math.h>
@@ -31,6 +33,9 @@
 typedef at::Half half;
 typedef at::BFloat16 bfloat16;

+typedef at::Float8_e4m3fn float8_e4m3fn;
+typedef at::Float8_e5m2 float8_e5m2;
+
 template <typename T>
 struct Welford {
   T mean = T(0);
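These typedefs make the c10 FP8 types visible to generated kernels under the short names the codegen emits. To check the end-to-end result, one can dump the generated C++ for a compiled FP8 cast; a sketch, assuming run_and_get_cpp_code is importable from torch._inductor.utils (it is used by the inductor test suite, but treat the exact location and signature as an assumption):

import torch
from torch._inductor.utils import run_and_get_cpp_code

def cast_to_fp8(x):
    return x.to(torch.float8_e4m3fn)

compiled = torch.compile(cast_to_fp8)
x = torch.rand(16, 16, device="cpu")
# Assumed helper: runs the callable and returns (result, generated C++ source).
_, code = run_and_get_cpp_code(compiled, x)
print("float8_e4m3fn" in code)  # expected True once this commit is in the build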