pytorch/torch/onnx/symbolic_opset12.py
Thiago Crepaldi a8f40b39ce Update all ONNX symbolics with new JitScalarType API (#87245)
Fixes https://github.com/pytorch/pytorch/issues/84365 and more

This PR addresses not only the issue above, but the entire family of issues related to `torch._C.Value.type()` parsing when `scalarType()` or `dtype()` is not available.

This issue existed before `JitScalarType` was introduced, but the new implementation carried the bug over, because the new APIs `from_name` and `from_dtype` require parsing `torch._C.Value.type()` to obtain their inputs, which is exactly the root cause of this family of bugs.

Therefore `from_name` and `from_dtype` should only be called when the implementor already knows the `name` or `dtype` without having to parse a `torch._C.Value`. To handle the corner cases hidden within `torch._C.Value`, a new `from_value` API was introduced; it should be preferred over the former two in most cases. The new API is safer and does not require type parsing from the user, which could otherwise trigger JIT asserts in the core of PyTorch.
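
A minimal sketch of the intended usage (not code from this PR; `value` is an illustrative `torch._C.Value` held by some symbolic function):

    from torch.onnx import _type_utils

    # `value` is a placeholder torch._C.Value produced by the exporter.
    scalar_type = _type_utils.JitScalarType.from_value(
        value, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        onnx_dtype = scalar_type.onnx_type()  # e.g. for a Cast node's to_i attribute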

Although CI is passing for all tests, please review all symbolic/helper refactorings carefully to make sure the meaning/intention of each old call is preserved by the new one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87245
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-11-03 03:01:33 +00:00


import functools
import sys
from typing import Optional, Tuple
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
_type_utils,
errors,
symbolic_helper,
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# This file exports ONNX ops for opset 12
__all__ = [
"argmax",
"argmin",
"binary_cross_entropy_with_logits",
"celu",
"cross_entropy_loss",
"dropout",
"einsum",
"ge",
"le",
"native_dropout",
"nll_loss",
"nll_loss2d",
"nll_loss_nd",
"outer",
"pow",
"tensordot",
"unfold",
]
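# Register every symbolic function in this file against ONNX opset 12.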
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
if not tensors:
raise RuntimeError("Einsum inputs are empty.")
# ONNX does not support bool for Einsum inputs.
if symbolic_helper._is_bool(tensors[0]):
tensors = [
g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
for tensor in tensors
]
return g.op(
"Cast",
g.op("Einsum", *tensors, equation_s=equation),
to_i=_C_onnx.TensorProtoDataType.BOOL,
)
else:
return g.op("Einsum", *tensors, equation_s=equation)
@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
@_beartype.beartype
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
tensors = symbolic_helper._unpack_list(tensor_list)
return _einsum_helper(g, equation, tensors)
@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g: jit_utils.GraphContext, input, other):
    # make sure to cast other to input's type
if _type_utils.JitScalarType.from_value(
other, _type_utils.JitScalarType.UNDEFINED
) != _type_utils.JitScalarType.from_value(input):
other = g.op(
"Cast",
other,
to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
)
return _einsum_helper(g, "i,j->ij", [input, other])
@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op. That is, if the node's
    # train param is set to False, dropout just returns its input.
if not train:
return input, None
p = g.op("Constant", value_t=torch.tensor(p))
t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
r, mask = g.op("Dropout", input, p, t, outputs=2)
return r, mask
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
return masked
@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g: jit_utils.GraphContext, input, p, train):
return _dropout_returns_masked_input_and_mask(g, input, p, train)
@_onnx_symbolic("aten::nll_loss")
@_beartype.beartype
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
# sum reduction : onnx::Constant[value={2}]
reduction = symbolic_helper._maybe_get_const(reduction, "i")
reduction_vals = ["none", "mean", "sum"]
reduction = reduction_vals[reduction]
    # In the ONNX NegativeLogLikelihoodLoss spec, ignore_index is optional and has no default value,
    # so we set the ignore_index attribute even when it was not explicitly specified (e.g. ignore_index=-100).
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
if weight.node().mustBeNone():
nllloss = g.op(
"NegativeLogLikelihoodLoss",
self,
target,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
else:
nllloss = g.op(
"NegativeLogLikelihoodLoss",
self,
target,
weight,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
return nllloss
@_onnx_symbolic("aten::nll_loss2d")
@_beartype.beartype
def nll_loss2d(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_onnx_symbolic("aten::nll_loss_nd")
@_beartype.beartype
def nll_loss_nd(
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
return nll_loss(g, self, target, weight, reduction, ignore_index)
@_onnx_symbolic("aten::cross_entropy_loss")
@_beartype.beartype
def cross_entropy_loss(
g: jit_utils.GraphContext,
self,
target,
weight,
reduction,
ignore_index,
label_smoothing,
):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]
# sum reduction : onnx::Constant[value={2}]
reduction = symbolic_helper._maybe_get_const(reduction, "i")
reduction_vals = ["none", "mean", "sum"]
reduction = reduction_vals[reduction]
label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
if label_smoothing is not None and label_smoothing > 0.0:
raise errors.SymbolicValueError(
"Unsupported: ONNX does not support label_smoothing", self
)
    # In the ONNX SoftmaxCrossEntropyLoss spec, ignore_index is optional and has no default value,
    # so we set the ignore_index attribute even when it was not explicitly specified (e.g. ignore_index=-100).
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
if weight.node().mustBeNone():
celoss = g.op(
"SoftmaxCrossEntropyLoss",
self,
target,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
else:
celoss = g.op(
"SoftmaxCrossEntropyLoss",
self,
target,
weight,
reduction_s=reduction,
ignore_index_i=ignore_index,
)
return celoss
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
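    # Expand BCE-with-logits into primitive ONNX ops:
    #   loss = -(target * log(sigmoid(input)) + (1 - target) * log(1 - sigmoid(input)))
    # with the positive term optionally scaled by pos_weight, the whole loss by weight,
    # and the result reduced according to `reduction`.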
p = g.op("Constant", value_t=torch.tensor([1]))
sig_x = opset9.sigmoid(g, input)
log_sig_x = opset9.log(g, sig_x)
sub_1_x = opset9.sub(g, p, sig_x)
sub_1_y = opset9.sub(g, p, target)
log_1_x = opset9.log(g, sub_1_x)
if pos_weight is None or symbolic_helper._is_none(pos_weight):
output = opset9.neg(
g,
opset9.add(
g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
),
)
else:
output = opset9.neg(
g,
opset9.add(
g,
opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
opset9.mul(g, sub_1_y, log_1_x),
),
)
if weight is not None and not symbolic_helper._is_none(weight):
output = opset9.mul(g, weight, output)
reduction = symbolic_helper._maybe_get_const(reduction, "i")
if reduction == 0:
return output
elif reduction == 1:
return g.op("ReduceMean", output, keepdims_i=0)
elif reduction == 2:
return g.op("ReduceSum", output, keepdims_i=0)
else:
return symbolic_helper._onnx_unsupported(
"binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
input,
)
@_onnx_symbolic("aten::celu")
@_beartype.beartype
def celu(g: jit_utils.GraphContext, self, alpha):
alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # If the input is of type double, cast it to float
if (
_type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
== _type_utils.JitScalarType.DOUBLE
):
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = g.op("Celu", self, alpha_f=alpha)
return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
return g.op("Celu", self, alpha_f=alpha)
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
g: jit_utils.GraphContext,
input: torch._C.Value,
dim: torch._C.Value,
keepdim: bool,
):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
g: jit_utils.GraphContext,
input: torch._C.Value,
dim: torch._C.Value,
keepdim: bool,
):
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
return g.op("Pow", self, exponent)
@_onnx_symbolic("aten::ge")
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
return g.op("GreaterOrEqual", input, other)
@_onnx_symbolic("aten::le")
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
return g.op("LessOrEqual", input, other)
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
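    # With a constant size and step, reuse the static opset 9 implementation.
    # Otherwise fall back to the caffe2 aten op when enabled, or emit an ONNX Loop that
    # slices one window per iteration and stitches the collected windows back together.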
const_size = symbolic_helper._maybe_get_const(size, "i")
const_step = symbolic_helper._maybe_get_const(step, "i")
if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
const_step
):
return opset9.unfold(g, input, dimension, const_size, const_step)
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
if sizedim is not None:
low_start = g.op("Constant", value_t=torch.tensor(0))
low_end = g.op("Constant", value_t=torch.tensor(sizedim))
hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
low_indices = g.op("Range", low_start, low_end, step)
hi_indices = g.op("Range", size, hi_end, step)
low_size = symbolic_helper._size_helper(
g, low_indices, g.op("Constant", value_t=torch.tensor(0))
)
hi_size = symbolic_helper._size_helper(
g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
)
ndim = symbolic_helper._get_tensor_rank(input)
assert ndim is not None
perm = list(range(0, ndim))
perm.append(perm.pop(dimension))
unsqueeze_list = []
loop_condition = g.op("Constant", value_t=torch.tensor(1))
loop_condition = g.op(
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
)
loop_len = g.op("Min", low_size, hi_size)
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
g, "Loop", loop_len, loop_condition, n_blocks=1
)
loop_block = loop_context.block
block_input_iter = utils._add_input_to_block(loop_block)
# FIXME(justinchuby): cond is unused?
cond = utils._add_input_to_block(loop_block)
starts = loop_context.op("Gather", low_indices, block_input_iter)
ends = loop_context.op("Gather", hi_indices, block_input_iter)
axes = loop_context.op("Constant", value_t=torch.tensor([2]))
starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
stack = loop_context.op("Slice", input, starts, ends, axes)
unsqueeze = symbolic_helper._unsqueeze_helper(
loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
)
unsqueeze_list.append(unsqueeze)
concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
        cond_out = loop_context.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
utils._add_output_to_block(loop_block, cond_out)
utils._add_output_to_block(loop_block, concat)
loop_output = loop.node().output()
perm = [0, 1, 2, 3, 4]
perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
transpose = g.op("Transpose", loop_output, perm_i=perm)
squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
return squeeze
return symbolic_helper._unimplemented("Unfold", "input size not accessible")
@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
@_beartype.beartype
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
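    # Strategy: permute the contracted dims of input_a to the end and those of input_b
    # to the front, flatten each operand into a 2-D matrix, contract the matrices with
    # Einsum("ij,jk->ik"), and reshape the result back to the remaining free dims.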
if out is not None:
symbolic_helper._unimplemented(
"Tensordot", "Out parameter is not supported for tensordot."
)
dim_count_a = symbolic_helper._get_tensor_rank(input_a)
if dim_count_a is None:
raise errors.SymbolicValueError(
"Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
input_a,
)
dim_count_b = symbolic_helper._get_tensor_rank(input_b)
if dim_count_b is None:
raise errors.SymbolicValueError(
"Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
input_b,
)
dims_a = [
(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
for i in range(len(dims_a))
]
dims_b = [
(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
for i in range(len(dims_b))
]
left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
input_shape = g.op("Shape", new_input_a)
left_sizes_a = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
)
shape_sizes = [
left_sizes_a,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
]
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
input_shape = g.op("Shape", output_a)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
)
shape_sizes = [
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
slices,
]
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
input_shape = g.op("Shape", new_input_b)
left_sizes_b = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
)
shape_sizes = [
slices,
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
]
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
input_shape = g.op("Shape", output_b)
slices = symbolic_helper._slice_helper(
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
)
shape_sizes = [
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
slices,
]
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
shape_sizes = [left_sizes_a, left_sizes_b]
return opset9._reshape_from_tensor(g, output, shape_sizes)