pytorch/torch/onnx/symbolic_opset15.py
Justin Chu 0d76299ff7 [ONNX] Clean up module imports (#77423)
Cleaning up onnx module imports to prepare for updating `__init__`.

- Simplify importing the `_C` and `_C._onnx` namespaces
- Remove alias of the symbolic_helper module in imports
- Remove any module-level function imports. Import modules instead
    - Alias `symbolic_opsetx` as `opsetx`
- Fix some docstrings

Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
2022-05-20 01:56:24 +00:00

"""This file exports ONNX ops for opset 15.

Note [ONNX operators that are added/updated in opset 15]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
New operators:
    Bernoulli
    CastLike
    Optional
    OptionalGetElement
    OptionalHasElement

Updated operators:
    BatchNormalization https://github.com/onnx/onnx/pull/3545
        Backwards compatible
        TODO: test coverage for mixed-type inputs.
    Pow https://github.com/onnx/onnx/pull/3412
        Backwards compatible
        TODO: bfloat16 support.
    Shape https://github.com/onnx/onnx/pull/3580
        Backwards compatible
        TODO: optional start/end attribute.
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx import symbolic_opset9 as opset9


def __is_(g, self, other):
    if symbolic_helper._is_none(other):
        if isinstance(self.type(), _C.OptionalType):
            none = g.op("OptionalHasElement", self)
            return g.op("Not", none)
        else:
            return g.op("Constant", value_t=torch.BoolTensor([0]))
    return opset9.eq(g, self, other)


@opset9.wrap_logical_op_with_negation
def __isnot_(g, self, other):
    return __is_(g, self, other)


class Prim:
    domain = "prim"

    @staticmethod
    def unchecked_cast(g, self):
        # exists to refine the type of the Value
        # if x is Optional[Tensor], unchecked_cast will cast
        # x to Tensor, so the rest of the graph knows that x is a Tensor.
        if isinstance(self.type(), _C.OptionalType):
            return g.op("OptionalGetElement", self)

        return self
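The `__is_`, `__isnot_`, and `Prim.unchecked_cast` symbolics in this module differ only in which ONNX nodes they emit, and that is hard to see without running the exporter. What follows is a minimal, torch-free sketch of that dispatch logic. It rests on a few stated assumptions. `FakeGraph`, `Value`, `TensorType`, and `OptionalType` are hypothetical stand-ins for the graph context `g` and the `torch._C` types. A plain `other is None` check stands in for `symbolic_helper._is_none(other)`, and a single ONNX `Equal` node stands in for `opset9.eq`. The sketch models the emitted node sequence only and is not the exporter's API.

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins for torch._C types; only isinstance checks matter here.
class TensorType:
    pass


class OptionalType:
    pass


@dataclass
class Value:
    """Hypothetical stand-in for a graph value: a name plus a type."""
    name: str
    ty: object

    def type(self):
        return self.ty


@dataclass
class FakeGraph:
    """Hypothetical stand-in for the graph context `g`; it only records nodes."""
    nodes: list = field(default_factory=list)

    def op(self, kind, *inputs, **attrs):
        # Record (op kind, input names); attributes such as value_t are ignored.
        self.nodes.append((kind, [v.name for v in inputs]))
        return Value(f"%{len(self.nodes) - 1}", TensorType())


def is_(g, x, other):
    # `other is None` stands in for symbolic_helper._is_none(other).
    if other is None:
        if isinstance(x.type(), OptionalType):
            return g.op("Not", g.op("OptionalHasElement", x))
        # A non-Optional value is never None, so fold to a constant False.
        return g.op("Constant", value_t=[False])
    # Stand-in for opset9.eq, modeled here as a single ONNX Equal node.
    return g.op("Equal", x, other)


def isnot(g, x, other):
    # Models wrap_logical_op_with_negation: build is_, then negate with Not.
    return g.op("Not", is_(g, x, other))


def unchecked_cast(g, x):
    # Optional[Tensor] -> Tensor via OptionalGetElement; other values pass through.
    if isinstance(x.type(), OptionalType):
        return g.op("OptionalGetElement", x)
    return x


def kinds(g):
    # The op kinds recorded so far, in emission order.
    return [kind for kind, _ in g.nodes]


if __name__ == "__main__":
    opt, t = Value("x", OptionalType()), Value("t", TensorType())
    cases = {
        "opt is None": lambda g: is_(g, opt, None),
        "opt is not None": lambda g: isnot(g, opt, None),
        "t is None": lambda g: is_(g, t, None),
        "unchecked_cast(opt)": lambda g: unchecked_cast(g, opt),
    }
    for label, build in cases.items():
        g = FakeGraph()
        build(g)
        print(label, "->", kinds(g))
```

Run as a script, the sketch prints the node kinds emitted for each case. Assume `wrap_logical_op_with_negation` wraps its function's result in a single `Not`. Under that assumption, `x is not None` on an Optional input ends up as two stacked `Not` nodes over `OptionalHasElement`, because `__isnot_` reuses `__is_` rather than emitting `OptionalHasElement` directly.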