pytorch/torch/csrc/onnx/back_compat.h
Aaron Bockover bd1229477d [ONNX] Add initial support for FP8 ONNX export (#107962)
This PR resurrects @tcherckez-nvidia's #106379 with changes to resolve conflicts against newer `main` and defines our own constants for the new ONNX types to [avoid breaking Meta's internal usage of an old ONNX](https://github.com/pytorch/pytorch/pull/106379#issuecomment-1675189340).

- `::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN=17`
- `::torch::onnx::TensorProto_DataType_FLOAT8E5M2=19`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107962
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
2023-09-08 20:40:39 +00:00

22 lines
707 B
C++

#pragma once
#include <onnx/onnx_pb.h>
namespace torch {
namespace onnx {

// Back-compat stand-ins for the FLOAT8 members of
// ::ONNX_NAMESPACE::TensorProto_DataType.
//
// Meta's internal ONNX pre-dates ONNX 1.14 and therefore has no FLOAT8
// support, so the values are defined here instead of being referenced by
// their upstream names. Background:
// https://github.com/pytorch/pytorch/pull/106379#issuecomment-1675189340
// -abock, 2023-08-25

// Numeric value 17, standing in for
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN.
constexpr ::ONNX_NAMESPACE::TensorProto_DataType
    TensorProto_DataType_FLOAT8E4M3FN =
        static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(17);

// Numeric value 19, standing in for
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2.
constexpr ::ONNX_NAMESPACE::TensorProto_DataType
    TensorProto_DataType_FLOAT8E5M2 =
        static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(19);

} // namespace onnx
} // namespace torch