diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py
index 58646320935..9fb69589f41 100644
--- a/test/onnx/test_fx_to_onnx.py
+++ b/test/onnx/test_fx_to_onnx.py
@@ -650,6 +650,35 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
 
         self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
 
+    @common_utils.parametrize(
+        "float8_type",
+        [
+            common_utils.subtest(
+                torch.float8_e5m2,
+                name="torch_float8_e5m2",
+            ),
+            common_utils.subtest(
+                torch.float8_e5m2fnuz,
+                name="torch_float8_e5m2fnuz",
+            ),
+            common_utils.subtest(
+                torch.float8_e4m3fn,
+                name="torch_float8_e4m3fn",
+            ),
+            common_utils.subtest(
+                torch.float8_e4m3fnuz,
+                name="torch_float8_e4m3fnuz",
+            ),
+        ],
+    )
+    def test_float8_support(self, float8_type):
+        class Float8Module(torch.nn.Module):
+            def forward(self, input: torch.Tensor):
+                input = input.to(float8_type)
+                return input + torch.tensor(1.0, dtype=float8_type)
+
+        _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py
index d3c11257724..b7f3d6cea64 100644
--- a/torch/onnx/_internal/fx/type_utils.py
+++ b/torch/onnx/_internal/fx/type_utils.py
@@ -124,6 +124,10 @@ _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[
     torch.float64: {"tensor(double)"},
     torch.float32: {"tensor(float)"},
     torch.float16: {"tensor(float16)"},
+    torch.float8_e4m3fn: {"tensor(float8_e4m3fn)"},
+    torch.float8_e4m3fnuz: {"tensor(float8_e4m3fnuz)"},
+    torch.float8_e5m2: {"tensor(float8_e5m2)"},
+    torch.float8_e5m2fnuz: {"tensor(float8_e5m2fnuz)"},
     torch.int16: {"tensor(int16)"},
     torch.int32: {"tensor(int32)"},
     torch.int64: {"tensor(int64)"},
@@ -174,6 +178,10 @@ _TORCH_DTYPE_TO_ABBREVIATION = {
     torch.float64: "f64",
     torch.float32: "f32",
     torch.float16: "f16",
+    torch.float8_e4m3fn: "e4m3fn",
+    torch.float8_e4m3fnuz: "e4m3fnuz",
+    torch.float8_e5m2: "f8e5m2",
+    torch.float8_e5m2fnuz: "e5m2fnuz",
     torch.complex32: "c32",
     torch.complex64: "c64",
     torch.complex128: "c128",
@@ -200,6 +208,10 @@ _TORCH_DTYPE_TO_NUMPY_DTYPE = {
 _ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE = {
     onnx.TensorProto.FLOAT: torch.float32,  # type: ignore[attr-defined]
     onnx.TensorProto.FLOAT16: torch.float16,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn,  # type: ignore[attr-defined]
+    onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,  # type: ignore[attr-defined]
     onnx.TensorProto.DOUBLE: torch.float64,  # type: ignore[attr-defined]
     onnx.TensorProto.BOOL: torch.bool,  # type: ignore[attr-defined]
     onnx.TensorProto.UINT8: torch.uint8,  # type: ignore[attr-defined]