Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68490
Using ATen as a fallback operator during ONNX conversion is important for increasing operator coverage, or even for providing more efficient implementations than some ONNX ops.
Currently this feature is available through `OperatorExportTypes.ONNX_ATEN_FALLBACK`,
but it also applies graph transformations whose output is runnable only by Caffe2.
This PR restricts the Caffe2-specific graph transformations for the `ONNX_ATEN_FALLBACK`
operator export type to builds of PyTorch with Caffe2 support (i.e. `BUILD_CAFFE2=1` at build time).
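The gating described above can be sketched as a small predicate. This is a minimal sketch, not the actual exporter code: the function name `apply_caffe2_transforms` is hypothetical, the enum is a local stand-in mirroring the `OperatorExportTypes` members declared in the stub, and the boolean argument stands in for the build-time flag exposed as `_CAFFE2_ATEN_FALLBACK`.

```python
from enum import Enum, auto


class OperatorExportTypes(Enum):
    # Local stand-in mirroring the members declared in the stub;
    # the member values here are arbitrary.
    ONNX = auto()
    ONNX_ATEN = auto()
    ONNX_ATEN_FALLBACK = auto()
    ONNX_FALLTHROUGH = auto()


def apply_caffe2_transforms(
    export_type: OperatorExportTypes, caffe2_aten_fallback: bool
) -> bool:
    """Hypothetical predicate: should Caffe2-specific graph passes run?

    Models the rule from this PR: only for ONNX_ATEN_FALLBACK, and only
    when PyTorch was built with Caffe2 support (BUILD_CAFFE2=1).
    """
    return (
        export_type is OperatorExportTypes.ONNX_ATEN_FALLBACK
        and caffe2_aten_fallback
    )


# Without Caffe2 in the build, ATen fallback exports skip the Caffe2-only rewrites.
print(apply_caffe2_transforms(OperatorExportTypes.ONNX_ATEN_FALLBACK, False))  # False
print(apply_caffe2_transforms(OperatorExportTypes.ONNX_ATEN_FALLBACK, True))   # True
print(apply_caffe2_transforms(OperatorExportTypes.ONNX, True))                 # False
```

Keeping the check as a single conjunction makes it explicit that `ONNX_ATEN_FALLBACK` alone is no longer enough to trigger the Caffe2 transformations.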
The first version of this PR introduced a new operator export type, `ONNX_ATEN__STRICT_FALLBACK`,
which was essentially the same as `ONNX_ATEN_FALLBACK` but without the Caffe2 transformations.
It was preferred not to introduce a new operator export type but to refine the existing ATen fallback one.
## BC-breaking note
### The global constant `torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE` is removed in favor of a less visible `torch.onnx._CAFFE2_ATEN_FALLBACK`.
`PYTORCH_ONNX_CAFFE2_BUNDLE` is effectively a dead-code flag that is always set to `False`.
One alternative would be to fix it, but #66658 disables the Caffe2 build by default.
Making this Caffe2 feature private makes more sense ahead of its future deprecation.
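Downstream code that read the removed constant may want to tolerate PyTorch versions from both before and after this change. The sketch below is an assumption, not an official migration path: `caffe2_aten_fallback_enabled` is a hypothetical helper that takes the `torch.onnx` module (or any object), prefers the new private `_CAFFE2_ATEN_FALLBACK` attribute, and falls back to the old `PYTORCH_ONNX_CAFFE2_BUNDLE`. Because the new attribute is private, it may change or disappear in later releases.

```python
from types import SimpleNamespace
from typing import Any


def caffe2_aten_fallback_enabled(onnx_module: Any) -> bool:
    """Hypothetical compatibility helper; in real code, pass `torch.onnx`.

    Prefers the private flag introduced by this PR, falls back to the
    removed public constant on older releases, and returns False if
    neither attribute exists.
    """
    if hasattr(onnx_module, "_CAFFE2_ATEN_FALLBACK"):
        return bool(onnx_module._CAFFE2_ATEN_FALLBACK)
    return bool(getattr(onnx_module, "PYTORCH_ONNX_CAFFE2_BUNDLE", False))


# Stand-ins for old and new torch.onnx modules, so the sketch runs without PyTorch.
old_onnx = SimpleNamespace(PYTORCH_ONNX_CAFFE2_BUNDLE=False)
new_onnx = SimpleNamespace(_CAFFE2_ATEN_FALLBACK=True)
print(caffe2_aten_fallback_enabled(old_onnx))  # False
print(caffe2_aten_fallback_enabled(new_onnx))  # True
```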
### The function `torch.onnx.export` now defaults to `OperatorExportTypes.ONNX` when `operator_export_type` is not specified.
Previously, `torch.onnx.export`'s `operator_export_type` was intended to default to `ONNX_ATEN_FALLBACK` when `PYTORCH_ONNX_CAFFE2_BUNDLE` was set, but that never happened because `PYTORCH_ONNX_CAFFE2_BUNDLE` was always undefined.
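The change in default selection can be illustrated side by side. This is a hedged sketch, not the code in `torch.onnx`: the helpers `old_default` and `new_default` are hypothetical, and the enum is a local stand-in for the `OperatorExportTypes` members declared in the stub.

```python
from enum import Enum, auto
from typing import Optional


class OperatorExportTypes(Enum):
    # Local stand-in for the members declared in the stub (values are arbitrary).
    ONNX = auto()
    ONNX_ATEN = auto()
    ONNX_ATEN_FALLBACK = auto()
    ONNX_FALLTHROUGH = auto()


def old_default(
    requested: Optional[OperatorExportTypes], caffe2_bundle: bool
) -> OperatorExportTypes:
    # Intended pre-PR behavior: use the ATen fallback when the Caffe2 bundle
    # flag was set. The flag was never set, so that branch never ran in practice.
    if requested is not None:
        return requested
    if caffe2_bundle:
        return OperatorExportTypes.ONNX_ATEN_FALLBACK
    return OperatorExportTypes.ONNX


def new_default(requested: Optional[OperatorExportTypes]) -> OperatorExportTypes:
    # Post-PR behavior: an unspecified operator_export_type always means ONNX.
    return requested if requested is not None else OperatorExportTypes.ONNX


print(old_default(None, caffe2_bundle=False))  # OperatorExportTypes.ONNX
print(new_default(None))                       # OperatorExportTypes.ONNX
```

With `caffe2_bundle=False`, the only value the flag ever had in practice, both helpers return `ONNX`, which matches the note that the fallback default never took effect; an explicitly requested export type is passed through unchanged in both.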
Co-authored-by: Nikita Shulga <nshulga@fb.com>
Test Plan: Imported from OSS
Reviewed By: jansel
Differential Revision: D32483781
Pulled By: malfet
fbshipit-source-id: e9b447db9466b369e77d747188685495aec3f124
(cherry picked from commit 5fb1eb1b19)
37 lines
644 B
Python
```python
# Defined in torch/csrc/onnx/init.cpp

from enum import Enum

_CAFFE2_ATEN_FALLBACK: bool
PRODUCER_VERSION: str

class TensorProtoDataType(Enum):
    UNDEFINED = ...
    FLOAT = ...
    UINT8 = ...
    INT8 = ...
    UINT16 = ...
    INT16 = ...
    INT32 = ...
    INT64 = ...
    STRING = ...
    BOOL = ...
    FLOAT16 = ...
    DOUBLE = ...
    UINT32 = ...
    UINT64 = ...
    COMPLEX64 = ...
    COMPLEX128 = ...
    BFLOAT16 = ...

class OperatorExportTypes(Enum):
    ONNX = ...
    ONNX_ATEN = ...
    ONNX_ATEN_FALLBACK = ...
    ONNX_FALLTHROUGH = ...

class TrainingMode(Enum):
    EVAL = ...
    PRESERVE = ...
    TRAINING = ...
```