Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add Float8 support to onnx exporter (#121281)
Fixes #106877
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121281
Approved by: https://github.com/BowenBao, https://github.com/titaiwangms
parent 5a2527db22
commit 418568d2e3
@@ -650,6 +650,35 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
         self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
 
+    @common_utils.parametrize(
+        "float8_type",
+        [
+            common_utils.subtest(
+                torch.float8_e5m2,
+                name="torch_float8_e5m2",
+            ),
+            common_utils.subtest(
+                torch.float8_e5m2fnuz,
+                name="torch_float8_e5m2fnuz",
+            ),
+            common_utils.subtest(
+                torch.float8_e4m3fn,
+                name="torch_float8_e4m3fn",
+            ),
+            common_utils.subtest(
+                torch.float8_e4m3fnuz,
+                name="torch_float8_e4m3fnuz",
+            ),
+        ],
+    )
+    def test_float8_support(self, float8_type):
+        class Float8Module(torch.nn.Module):
+            def forward(self, input: torch.Tensor):
+                input = input.to(float8_type)
+                return input + torch.tensor(1.0, dtype=float8_type)
+
+        _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
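Outside the test suite, the same path can be exercised with a short standalone script. The sketch below mirrors test_float8_support above; the ONNXProgram.model_proto accessor and the onnx.TensorProto.FLOAT8E4M3FN enum value are assumptions about the surrounding torch/onnx APIs rather than part of this diff, and the exact element types in the exported graph depend on how the exporter lowers the cast.

import onnx
import torch


class Float8Module(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # Cast activations to float8 inside the model, as in the test case above.
        x = x.to(torch.float8_e4m3fn)
        return x + torch.tensor(1.0, dtype=torch.float8_e4m3fn)


# Export with the dynamo-based exporter used by the new test.
onnx_program = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))

# Inspect the serialized graph's outputs (model_proto is assumed to expose it).
for output in onnx_program.model_proto.graph.output:
    is_float8 = output.type.tensor_type.elem_type == onnx.TensorProto.FLOAT8E4M3FN
    print(output.name, is_float8)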
@@ -124,6 +124,10 @@ _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[
     torch.float64: {"tensor(double)"},
     torch.float32: {"tensor(float)"},
     torch.float16: {"tensor(float16)"},
+    torch.float8_e4m3fn: {"tensor(float8_e4m3fn)"},
+    torch.float8_e4m3fnuz: {"tensor(float8_e4m3fnuz)"},
+    torch.float8_e5m2: {"tensor(float8_e5m2)"},
+    torch.float8_e5m2fnuz: {"tensor(float8_e5m2fnuz)"},
     torch.int16: {"tensor(int16)"},
     torch.int32: {"tensor(int32)"},
     torch.int64: {"tensor(int64)"},
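The dictionary extended here maps each torch dtype to the ONNX Runtime type strings it may bind to. The sketch below shows that kind of compatibility check with a local excerpt of the table; is_compatible is a hypothetical helper for illustration and does not import the private exporter module.

from typing import Dict, Set

import torch

# Local excerpt of the compatibility table, limited to the newly added float8 entries.
_FLOAT8_COMPAT: Dict[torch.dtype, Set[str]] = {
    torch.float8_e4m3fn: {"tensor(float8_e4m3fn)"},
    torch.float8_e4m3fnuz: {"tensor(float8_e4m3fnuz)"},
    torch.float8_e5m2: {"tensor(float8_e5m2)"},
    torch.float8_e5m2fnuz: {"tensor(float8_e5m2fnuz)"},
}


def is_compatible(dtype: torch.dtype, onnx_type_str: str) -> bool:
    # True if the ONNX Runtime type string is an allowed binding for the torch dtype.
    return onnx_type_str in _FLOAT8_COMPAT.get(dtype, set())


assert is_compatible(torch.float8_e5m2, "tensor(float8_e5m2)")
assert not is_compatible(torch.float8_e5m2, "tensor(float8_e4m3fn)")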
@@ -174,6 +178,10 @@ _TORCH_DTYPE_TO_ABBREVIATION = {
     torch.float64: "f64",
     torch.float32: "f32",
     torch.float16: "f16",
+    torch.float8_e4m3fn: "e4m3fn",
+    torch.float8_e4m3fnuz: "e4m3fnuz",
+    torch.float8_e5m2: "f8e5m2",
+    torch.float8_e5m2fnuz: "e5m2fnuz",
     torch.complex32: "c32",
     torch.complex64: "c64",
     torch.complex128: "c128",
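The abbreviation table is the kind of lookup that can back short, dtype-tagged identifiers (for example in debug or generated names); that use is an assumption for illustration. A minimal sketch with a local excerpt of the new entries:

import torch

# Local excerpt of the abbreviation table, limited to the newly added float8 entries.
_FLOAT8_ABBREVIATION = {
    torch.float8_e4m3fn: "e4m3fn",
    torch.float8_e4m3fnuz: "e4m3fnuz",
    torch.float8_e5m2: "f8e5m2",
    torch.float8_e5m2fnuz: "e5m2fnuz",
}


def tagged_name(base: str, dtype: torch.dtype) -> str:
    # Hypothetical helper: fall back to the dtype's string form outside the excerpt.
    abbrev = _FLOAT8_ABBREVIATION.get(dtype, str(dtype).replace("torch.", ""))
    return f"{base}_{abbrev}"


print(tagged_name("matmul", torch.float8_e5m2))  # -> matmul_f8e5m2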
@@ -200,6 +208,10 @@ _TORCH_DTYPE_TO_NUMPY_DTYPE = {
 _ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE = {
     onnx.TensorProto.FLOAT: torch.float32,  # type: ignore[attr-defined]
     onnx.TensorProto.FLOAT16: torch.float16,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,  # type: ignore[attr-defined]
     onnx.TensorProto.DOUBLE: torch.float64,  # type: ignore[attr-defined]
     onnx.TensorProto.BOOL: torch.bool,  # type: ignore[attr-defined]
     onnx.TensorProto.UINT8: torch.uint8,  # type: ignore[attr-defined]
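This last hunk adds the reverse direction: recovering a torch dtype from an ONNX tensor element type. A sketch of that lookup, again with a local excerpt of the table and a hypothetical helper name:

import onnx
import torch

# Local excerpt mapping ONNX float8 element types back to torch dtypes.
_ONNX_ELEM_TYPE_TO_TORCH_DTYPE = {
    onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2,
    onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
    onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn,
    onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
}


def torch_dtype_of(value_info: onnx.ValueInfoProto) -> torch.dtype:
    # Look up the torch dtype for a graph input's or output's element type.
    elem_type = value_info.type.tensor_type.elem_type
    try:
        return _ONNX_ELEM_TYPE_TO_TORCH_DTYPE[elem_type]
    except KeyError:
        raise NotImplementedError(f"unsupported ONNX element type: {elem_type}") from None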