Fixes https://github.com/pytorch/pytorch/issues/84365 and more

This PR addresses not only the issue above, but the entire family of issues related to parsing `torch._C.Value.type()` when `scalarType()` or `dtype()` is not available. The bug predates `JitScalarType`, but the new implementation carried it over, because the new `from_name` and `from_dtype` APIs require parsing `torch._C.Value.type()` to obtain their inputs, which is exactly the root cause of this family of bugs. Therefore `from_name` and `from_dtype` should only be called when the implementor already knows the `name` and `dtype` without parsing a `torch._C.Value`. To handle the corner cases hidden within `torch._C.Value`, a new `from_value` API was introduced; it should be preferred over the former ones in most cases. The new API is safer because it does not require type parsing from the user, which could trigger JIT asserts in the core of PyTorch.

Although CI is passing for all tests, please review all symbolics/helpers refactoring carefully to make sure the meaning/intention of the old calls is unchanged in the new calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87245
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
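As a rough sketch of the migration described above (illustrative only; `value` stands for any `torch._C.Value` an exporter receives, and the `UNDEFINED` fallback mirrors the call sites in this file):

```python
from torch.onnx import _type_utils

# Before: callers parsed torch._C.Value.type() themselves, which asserts
# inside the JIT when scalarType() is unavailable:
#     scalar_type = _type_utils.JitScalarType.from_name(
#         value.type().scalarType()  # may be None -> crash
#     )

# After: from_value takes the torch._C.Value directly, with an optional
# default for values whose type cannot be resolved:
scalar_type = _type_utils.JitScalarType.from_value(
    value, _type_utils.JitScalarType.UNDEFINED
)
onnx_type = scalar_type.onnx_type()
```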
import functools
import sys
from typing import Optional, Tuple

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12

__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
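# Note: `_onnx_symbolic` is `registration.onnx_symbolic` with opset=12 bound,
# so each `@_onnx_symbolic("aten::<op>")` decorator below registers the
# decorated function as the opset-12 exporter for the named ATen op.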

@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)
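# Sketch of the emitted pattern: because ONNX Einsum has no bool overload, an
# equation such as "ij,jk->ik" on bool tensors becomes
#   Cast(to=INT64) -> Einsum -> Cast(to=BOOL)
# while non-bool inputs map to a bare Einsum node.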
@_onnx_symbolic("aten::einsum")
|
|
@symbolic_helper.parse_args("s", "v", "is")
|
|
@_beartype.beartype
|
|
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
|
|
tensors = symbolic_helper._unpack_list(tensor_list)
|
|
return _einsum_helper(g, equation, tensors)
|
|
|
|
|
|
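# Note: `path` appears to be the opt_einsum contraction-order hint from the
# aten::einsum schema; it is accepted here for signature compatibility but is
# not used by the ONNX export.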
@_onnx_symbolic("aten::outer")
|
|
@symbolic_helper.parse_args("v", "v")
|
|
@_beartype.beartype
|
|
def outer(g: jit_utils.GraphContext, input, other):
|
|
# make sure to cast other to self's type
|
|
if _type_utils.JitScalarType.from_value(
|
|
other, _type_utils.JitScalarType.UNDEFINED
|
|
) != _type_utils.JitScalarType.from_value(input):
|
|
other = g.op(
|
|
"Cast",
|
|
other,
|
|
to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
|
|
)
|
|
return _einsum_helper(g, "i,j->ij", [input, other])
|
|
|
|
|
|
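# Sketch: for a float32 `input` and an int64 `other`, the exporter above emits
# Cast(other, to=FLOAT) followed by Einsum("i,j->ij"), so the result takes
# `input`'s scalar type.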

@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op: if the node's train param is set to
    # False, dropout simply returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask
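# Note: since ONNX opset 12, Dropout takes `ratio` and `training_mode` as
# graph inputs rather than attributes and can return the mask as a second
# output, which is what the helper above relies on.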
@_onnx_symbolic("aten::dropout")
|
|
@symbolic_helper.parse_args("v", "f", "b")
|
|
@_beartype.beartype
|
|
def dropout(g: jit_utils.GraphContext, input, p, train):
|
|
masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
|
|
return masked
|
|
|
|
|
|
@_onnx_symbolic("aten::native_dropout")
|
|
@symbolic_helper.parse_args("v", "f", "b")
|
|
@_beartype.beartype
|
|
def native_dropout(g: jit_utils.GraphContext, input, p, train):
|
|
return _dropout_returns_masked_input_and_mask(g, input, p, train)
|
|
|
|
|
|
@_onnx_symbolic("aten::nll_loss")
|
|
@_beartype.beartype
|
|
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
|
|
# none reduction : onnx::Constant[value={0}]
|
|
# mean reduction : onnx::Constant[value={1}]
|
|
# sum reduction : onnx::Constant[value={2}]
|
|
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
|
reduction_vals = ["none", "mean", "sum"]
|
|
reduction = reduction_vals[reduction]
|
|
|
|
# in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
|
|
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
|
|
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
|
|
if weight.node().mustBeNone():
|
|
nllloss = g.op(
|
|
"NegativeLogLikelihoodLoss",
|
|
self,
|
|
target,
|
|
reduction_s=reduction,
|
|
ignore_index_i=ignore_index,
|
|
)
|
|
else:
|
|
nllloss = g.op(
|
|
"NegativeLogLikelihoodLoss",
|
|
self,
|
|
target,
|
|
weight,
|
|
reduction_s=reduction,
|
|
ignore_index_i=ignore_index,
|
|
)
|
|
|
|
return nllloss
|
|
|
|
|
|
@_onnx_symbolic("aten::nll_loss2d")
|
|
@_beartype.beartype
|
|
def nll_loss2d(
|
|
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
|
):
|
|
return nll_loss(g, self, target, weight, reduction, ignore_index)
|
|
|
|
|
|
@_onnx_symbolic("aten::nll_loss_nd")
|
|
@_beartype.beartype
|
|
def nll_loss_nd(
|
|
g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
|
|
):
|
|
return nll_loss(g, self, target, weight, reduction, ignore_index)
|
|
|
|
|
|
@_onnx_symbolic("aten::cross_entropy_loss")
|
|
@_beartype.beartype
|
|
def cross_entropy_loss(
|
|
g: jit_utils.GraphContext,
|
|
self,
|
|
target,
|
|
weight,
|
|
reduction,
|
|
ignore_index,
|
|
label_smoothing,
|
|
):
|
|
# none reduction : onnx::Constant[value={0}]
|
|
# mean reduction : onnx::Constant[value={1}]
|
|
# sum reduction : onnx::Constant[value={2}]
|
|
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
|
reduction_vals = ["none", "mean", "sum"]
|
|
reduction = reduction_vals[reduction]
|
|
|
|
label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
|
|
if label_smoothing is not None and label_smoothing > 0.0:
|
|
raise errors.SymbolicValueError(
|
|
"Unsupported: ONNX does not support label_smoothing", self
|
|
)
|
|
|
|
# in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
|
|
# therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
|
|
ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
|
|
if weight.node().mustBeNone():
|
|
celoss = g.op(
|
|
"SoftmaxCrossEntropyLoss",
|
|
self,
|
|
target,
|
|
reduction_s=reduction,
|
|
ignore_index_i=ignore_index,
|
|
)
|
|
else:
|
|
celoss = g.op(
|
|
"SoftmaxCrossEntropyLoss",
|
|
self,
|
|
target,
|
|
weight,
|
|
reduction_s=reduction,
|
|
ignore_index_i=ignore_index,
|
|
)
|
|
|
|
return celoss
|
|
|
|
|
|
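# Note: ONNX SoftmaxCrossEntropyLoss fuses LogSoftmax with
# NegativeLogLikelihoodLoss, which is why aten::cross_entropy_loss maps to a
# single ONNX op here instead of reusing nll_loss above.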
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
|
|
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
|
|
@_beartype.beartype
|
|
def binary_cross_entropy_with_logits(
|
|
g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
|
|
):
|
|
p = g.op("Constant", value_t=torch.tensor([1]))
|
|
sig_x = opset9.sigmoid(g, input)
|
|
log_sig_x = opset9.log(g, sig_x)
|
|
sub_1_x = opset9.sub(g, p, sig_x)
|
|
sub_1_y = opset9.sub(g, p, target)
|
|
log_1_x = opset9.log(g, sub_1_x)
|
|
if pos_weight is None or symbolic_helper._is_none(pos_weight):
|
|
output = opset9.neg(
|
|
g,
|
|
opset9.add(
|
|
g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
|
|
),
|
|
)
|
|
else:
|
|
output = opset9.neg(
|
|
g,
|
|
opset9.add(
|
|
g,
|
|
opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
|
|
opset9.mul(g, sub_1_y, log_1_x),
|
|
),
|
|
)
|
|
|
|
if weight is not None and not symbolic_helper._is_none(weight):
|
|
output = opset9.mul(g, weight, output)
|
|
|
|
reduction = symbolic_helper._maybe_get_const(reduction, "i")
|
|
if reduction == 0:
|
|
return output
|
|
elif reduction == 1:
|
|
return g.op("ReduceMean", output, keepdims_i=0)
|
|
elif reduction == 2:
|
|
return g.op("ReduceSum", output, keepdims_i=0)
|
|
else:
|
|
return symbolic_helper._onnx_unsupported(
|
|
"binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
|
|
input,
|
|
)
|
|
|
|
|
|
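# For reference, the graph built above computes the textbook form
#   loss = -(pos_weight * target * log(sigmoid(x))
#            + (1 - target) * log(1 - sigmoid(x)))
# optionally scaled elementwise by `weight`, then reduced per `reduction`.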
@_onnx_symbolic("aten::celu")
|
|
@_beartype.beartype
|
|
def celu(g: jit_utils.GraphContext, self, alpha):
|
|
alpha = symbolic_helper._maybe_get_const(alpha, "f")
|
|
# if the input is of type double cast it to float
|
|
if (
|
|
_type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
|
|
== _type_utils.JitScalarType.DOUBLE
|
|
):
|
|
self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
|
out = g.op("Celu", self, alpha_f=alpha)
|
|
return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
|
|
|
|
return g.op("Celu", self, alpha_f=alpha)
|
|
|
|
|
|
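# Note: ONNX Celu is only defined for float32 tensors (its alpha attribute is
# also float32), hence the double -> float -> double round-trip above.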
@_onnx_symbolic("aten::argmax")
|
|
@symbolic_helper.parse_args("v", "v", "b")
|
|
@_beartype.beartype
|
|
def argmax(
|
|
g: jit_utils.GraphContext,
|
|
input: torch._C.Value,
|
|
dim: torch._C.Value,
|
|
keepdim: bool,
|
|
):
|
|
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
|
|
|
|
|
|
@_onnx_symbolic("aten::argmin")
|
|
@symbolic_helper.parse_args("v", "v", "b")
|
|
@_beartype.beartype
|
|
def argmin(
|
|
g: jit_utils.GraphContext,
|
|
input: torch._C.Value,
|
|
dim: torch._C.Value,
|
|
keepdim: bool,
|
|
):
|
|
return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
|
|
|
|
|
|
@_onnx_symbolic("aten::pow")
|
|
@_beartype.beartype
|
|
def pow(g: jit_utils.GraphContext, self, exponent):
|
|
return g.op("Pow", self, exponent)
|
|
|
|
|
|
@_onnx_symbolic("aten::ge")
|
|
@_beartype.beartype
|
|
def ge(g: jit_utils.GraphContext, input, other):
|
|
return g.op("GreaterOrEqual", input, other)
|
|
|
|
|
|
@_onnx_symbolic("aten::le")
|
|
@_beartype.beartype
|
|
def le(g: jit_utils.GraphContext, input, other):
|
|
return g.op("LessOrEqual", input, other)
|
|
|
|
|
|
@_onnx_symbolic("aten::unfold")
|
|
@symbolic_helper.parse_args("v", "i", "v", "v")
|
|
@_beartype.beartype
|
|
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
|
|
const_size = symbolic_helper._maybe_get_const(size, "i")
|
|
const_step = symbolic_helper._maybe_get_const(step, "i")
|
|
if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
|
|
const_step
|
|
):
|
|
return opset9.unfold(g, input, dimension, const_size, const_step)
|
|
if symbolic_helper.is_caffe2_aten_fallback():
|
|
return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
|
|
|
|
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
|
|
if sizedim is not None:
|
|
low_start = g.op("Constant", value_t=torch.tensor(0))
|
|
low_end = g.op("Constant", value_t=torch.tensor(sizedim))
|
|
hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
|
|
low_indices = g.op("Range", low_start, low_end, step)
|
|
hi_indices = g.op("Range", size, hi_end, step)
|
|
|
|
low_size = symbolic_helper._size_helper(
|
|
g, low_indices, g.op("Constant", value_t=torch.tensor(0))
|
|
)
|
|
hi_size = symbolic_helper._size_helper(
|
|
g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
|
|
)
|
|
|
|
ndim = symbolic_helper._get_tensor_rank(input)
|
|
assert ndim is not None
|
|
perm = list(range(0, ndim))
|
|
perm.append(perm.pop(dimension))
|
|
|
|
unsqueeze_list = []
|
|
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
|
loop_condition = g.op(
|
|
"Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
|
|
)
|
|
loop_len = g.op("Min", low_size, hi_size)
|
|
|
|
loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
|
|
g, "Loop", loop_len, loop_condition, n_blocks=1
|
|
)
|
|
|
|
loop_block = loop_context.block
|
|
block_input_iter = utils._add_input_to_block(loop_block)
|
|
# FIXME(justinchuby): cond is unused?
|
|
cond = utils._add_input_to_block(loop_block)
|
|
|
|
starts = loop_context.op("Gather", low_indices, block_input_iter)
|
|
ends = loop_context.op("Gather", hi_indices, block_input_iter)
|
|
axes = loop_context.op("Constant", value_t=torch.tensor([2]))
|
|
starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
|
|
ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
|
|
stack = loop_context.op("Slice", input, starts, ends, axes)
|
|
|
|
unsqueeze = symbolic_helper._unsqueeze_helper(
|
|
loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
|
|
)
|
|
unsqueeze_list.append(unsqueeze)
|
|
concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
|
|
|
|
cond_out = loop_context.op(
|
|
"Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL
|
|
)
|
|
utils._add_output_to_block(loop_block, cond_out)
|
|
utils._add_output_to_block(loop_block, concat)
|
|
|
|
loop_output = loop.node().output()
|
|
perm = [0, 1, 2, 3, 4]
|
|
perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
|
|
transpose = g.op("Transpose", loop_output, perm_i=perm)
|
|
squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
|
|
|
|
return squeeze
|
|
|
|
return symbolic_helper._unimplemented("Unfold", "input size not accessible")
|
|
|
|
|
|
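# For reference: aten::unfold(input, dimension, size, step) returns all
# windows of length `size` taken every `step` elements along `dimension`,
# e.g. torch.arange(5).unfold(0, 2, 2) -> [[0, 1], [2, 3]]. When `size` or
# `step` is not a compile-time constant, the Loop above slices one window per
# iteration and concatenates the results.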
@_onnx_symbolic("aten::tensordot")
|
|
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
|
|
@_beartype.beartype
|
|
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
|
|
if out is not None:
|
|
symbolic_helper._unimplemented(
|
|
"Tensordot", "Out parameter is not supported for tensordot."
|
|
)
|
|
|
|
dim_count_a = symbolic_helper._get_tensor_rank(input_a)
|
|
if dim_count_a is None:
|
|
raise errors.SymbolicValueError(
|
|
"Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
|
|
input_a,
|
|
)
|
|
|
|
dim_count_b = symbolic_helper._get_tensor_rank(input_b)
|
|
if dim_count_b is None:
|
|
raise errors.SymbolicValueError(
|
|
"Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
|
|
input_b,
|
|
)
|
|
|
|
dims_a = [
|
|
(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
|
|
for i in range(len(dims_a))
|
|
]
|
|
dims_b = [
|
|
(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
|
|
for i in range(len(dims_b))
|
|
]
|
|
|
|
left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
|
|
left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
|
|
|
|
new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
|
|
new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
|
|
|
|
input_shape = g.op("Shape", new_input_a)
|
|
left_sizes_a = symbolic_helper._slice_helper(
|
|
g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
|
|
)
|
|
shape_sizes = [
|
|
left_sizes_a,
|
|
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
|
]
|
|
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
|
|
|
|
input_shape = g.op("Shape", output_a)
|
|
slices = symbolic_helper._slice_helper(
|
|
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
|
|
)
|
|
shape_sizes = [
|
|
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
|
slices,
|
|
]
|
|
output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
|
|
|
|
input_shape = g.op("Shape", new_input_b)
|
|
left_sizes_b = symbolic_helper._slice_helper(
|
|
g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
|
|
)
|
|
slices = symbolic_helper._slice_helper(
|
|
g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
|
|
)
|
|
shape_sizes = [
|
|
slices,
|
|
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
|
]
|
|
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
|
|
|
|
input_shape = g.op("Shape", output_b)
|
|
slices = symbolic_helper._slice_helper(
|
|
g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
|
|
)
|
|
shape_sizes = [
|
|
g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
|
|
slices,
|
|
]
|
|
output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
|
|
|
|
output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
|
|
|
|
shape_sizes = [left_sizes_a, left_sizes_b]
|
|
return opset9._reshape_from_tensor(g, output, shape_sizes)
|
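# For reference, the construction above follows the standard tensordot
# lowering: move the contracted dims of `input_a` to the back and those of
# `input_b` to the front, flatten both to 2-D, contract with
# einsum("ij,jk->ik") (a matmul), and reshape back to the kept dims, e.g.
# a: (3, 4, 5), b: (5, 6), dims=([2], [0]) -> result shape (3, 4, 6).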