pytorch/torch/_C/_onnx.pyi
Aaron Bockover bd1229477d [ONNX] Add initial support for FP8 ONNX export (#107962)
This PR resurrects @tcherckez-nvidia's #106379 with changes to resolve conflicts against newer `main` and defines our own constants for the new ONNX types to [avoid breaking Meta's internal usage of an old ONNX](https://github.com/pytorch/pytorch/pull/106379#issuecomment-1675189340).

- `::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN=17`
- `::torch::onnx::TensorProto_DataType_FLOAT8E5M2=19`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107962
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
2023-09-08 20:40:39 +00:00
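As a quick sanity check, a minimal sketch (assuming a PyTorch build that includes this PR) of how the new members show up on the Python side; the expected integer values are the constants listed above:

```python
# Sketch only: verify the new FP8 members on the pybind11 enum.
# Assumes a build containing this PR; 17 and 19 are the constants
# defined above for FLOAT8E4M3FN and FLOAT8E5M2 respectively.
from torch._C._onnx import TensorProtoDataType

assert int(TensorProtoDataType.FLOAT8E4M3FN) == 17
assert int(TensorProtoDataType.FLOAT8E5M2) == 19
```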

# Defined in torch/csrc/onnx/init.cpp

from enum import Enum

_CAFFE2_ATEN_FALLBACK: bool
PRODUCER_VERSION: str

class TensorProtoDataType(Enum):
    UNDEFINED = ...
    FLOAT = ...
    UINT8 = ...
    INT8 = ...
    UINT16 = ...
    INT16 = ...
    INT32 = ...
    INT64 = ...
    STRING = ...
    BOOL = ...
    FLOAT16 = ...
    DOUBLE = ...
    UINT32 = ...
    UINT64 = ...
    COMPLEX64 = ...
    COMPLEX128 = ...
    BFLOAT16 = ...
    FLOAT8E5M2 = ...
    FLOAT8E4M3FN = ...

class OperatorExportTypes(Enum):
    ONNX = ...
    ONNX_ATEN = ...
    ONNX_ATEN_FALLBACK = ...
    ONNX_FALLTHROUGH = ...

class TrainingMode(Enum):
    EVAL = ...
    PRESERVE = ...
    TRAINING = ...
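
For context, a hedged usage sketch showing where two of the enums declared in this stub surface in the public API: `TrainingMode` and `OperatorExportTypes` are re-exported from `torch.onnx` and passed to the TorchScript-based exporter. The model, input shape, output path, and opset version below are placeholders:

```python
# Usage sketch (not part of the stub): passing TrainingMode and
# OperatorExportTypes to torch.onnx.export.
import torch
from torch.onnx import OperatorExportTypes, TrainingMode

model = torch.nn.Linear(4, 2)       # placeholder model
dummy_input = torch.randn(1, 4)     # placeholder example input

torch.onnx.export(
    model,
    dummy_input,
    "linear.onnx",                  # placeholder output path
    training=TrainingMode.EVAL,
    operator_export_type=OperatorExportTypes.ONNX,
    opset_version=17,               # placeholder opset choice
)
```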