from __future__ import absolute_import, division, print_function, unicode_literals

from sys import maxsize

import torch
import torch.onnx.symbolic_helper as sym_help
import warnings
import numpy

from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_opset9 import expand, unused
from torch.nn.modules.utils import _single, _pair, _triple


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

# This file exports ONNX ops for opset 11

@parse_args('v', 'f', 'f')
def hardtanh(g, self, min_val, max_val):
    dtype = self.type().scalarType()
    if dtype is None:
        dtype = 6  # float
    else:
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
    min_val = g.op("Constant", value_t=torch.tensor(min_val, dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
    max_val = g.op("Constant", value_t=torch.tensor(max_val, dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
    return g.op("Clip", self, min_val, max_val)

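# In opset 11, onnx::Clip takes min/max as optional tensor inputs rather than
# the float attributes used in opset 9 and earlier, which is why the bounds in
# hardtanh above and clamp below are built as Constant/Cast nodes instead of
# being attached as min_f/max_f attributes.
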
def clamp(g, self, min, max):
    dtype = self.type().scalarType()

    def _cast_if_not_none(tensor, dtype):
        if tensor is not None and not sym_help._is_none(tensor):
            return g.op("Cast", tensor, to_i=sym_help.cast_pytorch_to_onnx[dtype])
        else:
            return tensor

    if dtype is not None:
        min = _cast_if_not_none(min, dtype)
        max = _cast_if_not_none(max, dtype)
    return g.op("Clip", self, min, max)

def clamp_min(g, self, min):
    max = unused(g)
    return clamp(g, self, min, max)

def clamp_max(g, self, max):
    min = unused(g)
    return clamp(g, self, min, max)

# Opset 11 gather accepts negative indices
@parse_args('v', 'i', 'v')
def select(g, self, dim, index):
    return g.op("Gather", self, index, axis_i=dim)

def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1]) for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self), value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

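# Rough sketch of the ScatterND lowering above, for a hypothetical call
# self[idx0, idx1] = values with self of shape (S0, S1, S2) and idx0/idx1
# broadcasting to shape B:
#   * each index tensor is expanded to B and unsqueezed to B x 1,
#   * they are concatenated into an index of shape B x 2 (one (i, j) pair per update),
#   * values is reshaped to B x S2 (broadcast shape + remaining data dims),
#   * ScatterND(self, index, values) then writes each update; with accumulate=True
#     the updates are scattered into zeros and added back to self, since opset 11
#     ScatterND has no accumulation mode.
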
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
    dims = self.type().sizes()
    if len(dims) != 4:
        return _unimplemented("pixel_shuffle", "only support 4d input")
    return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")

def _interpolate(name, dim, interpolate_mode):
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
        align_corners = sym_help._maybe_get_scalar(align_corners)
        coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
            else "align_corners" if align_corners else "pytorch_half_pixel"
        empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

        if scales is None:
            input_size = g.op("Shape", input)
            input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
            output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
            output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
            scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
            return g.op("Resize",
                        input,
                        empty_tensor,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                        scales,  # scales is not needed since we are sending out_size
                        output_size,
                        coordinate_transformation_mode_s=coordinate_transformation_mode,
                        cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                        mode_s=interpolate_mode,  # nearest, linear, or cubic
                        nearest_mode_s="floor")  # only valid when mode="nearest"
        else:
            return g.op("Resize",
                        input,
                        empty_tensor,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                        scales,  # scales determines the output size in this branch
                        coordinate_transformation_mode_s=coordinate_transformation_mode,
                        cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                        mode_s=interpolate_mode,  # nearest, linear, or cubic
                        nearest_mode_s="floor")  # only valid when mode="nearest"
    return symbolic_fn

upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
upsample_linear1d = _interpolate('upsample_linear1d', 3, "linear")
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear")
upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, "linear")
upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, "cubic")

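# The opset 11 onnx::Resize signature used above is Resize(X, roi, scales, sizes):
# exactly one of scales/sizes should carry data, so the branch that sends an
# explicit output size passes an empty scales tensor, and the scales branch
# omits the sizes input entirely.
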
def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
    mode = sym_help._maybe_get_const(mode, 's')
    if 'linear' in mode:
        mode = 'linear'
    if 'cubic' in mode:
        mode = 'cubic'
    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    align_corners = False if not isinstance(align_corners, bool) else align_corners
    coordinate_transformation_mode = "asymmetric" if mode == "nearest" \
        else "align_corners" if align_corners else "pytorch_half_pixel"
    # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

    if not sym_help._is_none(size):
        input_size = g.op("Shape", input)
        input_size = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
        # In some cases size is a scalar rather than a packed list. We would
        # need to verify that (sym_help._maybe_get_const(size, 't').dim() == 0),
        # but this information is not always available. Try to get the dim,
        # and if that fails assume that size is not a scalar.
        try:
            is_scalar = not sym_help._is_packed_list(size) and ((sym_help._maybe_get_const(size, 't').dim() == 0))
        except AttributeError:
            is_scalar = not sym_help._is_packed_list(size)
            if not is_scalar:
                warnings.warn("Cannot verify if the output_size is a scalar "
                              "while exporting interpolate. Assuming that it is not a scalar.")

        if is_scalar:
            if not input.type().dim():
                return sym_help._unimplemented("interpolate (with a scalar output_size)",
                                               "missing input shape (try giving an array of output_size values)")
            size = unsqueeze(g, size, 0)
            size = [size for i in range(input.type().dim() - 2)]
            size = g.op("Concat", *size, axis_i=0)
        size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
        size = g.op("Concat", input_size, size, axis_i=0)
        scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    size,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")
    else:  # if not sym_help._is_none(scales)
        if not input.type().dim():
            return sym_help._unimplemented("interpolate (with scales)", "missing input shape")
        scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim())
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"

@parse_args('v', 'i', 'v', 'v')
def gather(g, self, dim, index, sparse_grad=False):
    if sym_help._maybe_get_const(sparse_grad, 'i'):
        return _unimplemented("gather", "sparse_grad == True")
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, dim, index, sparse_grad, operator_s="gather")
    return g.op("GatherElements", self, index, axis_i=dim)

@parse_args('v', 'i', 'v', 'v')
def scatter(g, self, dim, index, src):
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, dim, index, src, operator_s="scatter")
    return g.op("ScatterElements", self, index, src, axis_i=dim)

@parse_args('v', 'i', 'none')
def cumsum(g, self, dim, dtype=None):
    dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
    if dtype and dtype.node().kind() != 'prim::Constant':
        parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
        cast = g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
    else:
        cast = self
    csum = g.op("CumSum", cast, dim_tensor)
    return csum

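# onnx::CumSum supports fewer input types than torch.cumsum (notably no bool),
# so when an explicit dtype is given the Cast above is applied to the input
# *before* CumSum rather than to its output. A hypothetical
# torch.cumsum(x, 0, dtype=torch.int64) on a bool tensor x thus exports as
# Cast(x, to=INT64) -> CumSum, which ONNX can type-check.
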
def masked_select(g, self, mask):
    from torch.onnx.symbolic_opset9 import nonzero, expand_as
    index = nonzero(g, expand_as(g, mask, self))
    return g.op('GatherND', self, index)

def masked_scatter(g, self, mask, source):
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size
    index = nonzero(g, expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
    source = view(g, source, torch.LongTensor([-1]))
    source = sym_help._slice_helper(g, source,
                                    axes=torch.LongTensor([0]),
                                    starts=torch.LongTensor([0]),
                                    ends=size(g, index, torch.LongTensor([0])),
                                    dynamic_slice=True)
    return g.op('ScatterND', self, index, source)

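# Flattening sketch (hypothetical values): for self = zeros(2, 2),
# mask = [[True, False], [False, True]] and source = tensor([1., 2., 3.]),
# index becomes [[0, 0], [1, 1]], source is flattened and sliced to its first
# two elements, and ScatterND writes 1. at (0, 0) and 2. at (1, 1).
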
def _len(g, self):
    return g.op("SequenceLength", self)

def __getitem_(g, self, i):
    if self.type().isSubtypeOf(torch._C.ListType.ofTensors()):
        # SequenceAt requires that the input be a List of Tensors
        return g.op("SequenceAt", self, i)
    else:
        from torch.onnx.symbolic_opset9 import __getitem_ as getitem
        return getitem(g, self, i)

def append(g, self, tensor):
    return g.op("SequenceInsert", self, tensor)

def insert(g, self, pos, tensor):
    return g.op("SequenceInsert", self, tensor, pos)

def pop(g, tensor_list, dim):
    return g.op("SequenceErase", tensor_list, dim)

def cat(g, tensor_list, dim):
    if sym_help._is_packed_list(tensor_list):
        from torch.onnx.symbolic_opset9 import cat as cat_opset9
        return cat_opset9(g, tensor_list, dim)
    else:
        dim = sym_help._get_const(dim, 'i', 'dim')
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim)

def stack(g, tensor_list, dim):
    if sym_help._is_packed_list(tensor_list):
        from torch.onnx.symbolic_opset9 import stack as stack_opset9
        return stack_opset9(g, tensor_list, dim)
    else:
        dim = sym_help._get_const(dim, 'i', 'dim')
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1)

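# For dynamic (non-packed) tensor lists, cat and stack both lower to
# onnx::ConcatFromSequence; new_axis_i=1 is what turns concatenation into
# stacking by inserting a fresh axis at `dim`, mirroring torch.stack vs torch.cat.
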
@parse_args('v', 'i', 'i', 'i')
def _unique2(g, self, sorted, return_inverse, return_counts):
    u, indices, inverse_indices, counts = g.op("Unique", self, sorted_i=sorted, outputs=4)
    return u, inverse_indices, counts

def _avg_pool(name, tuple_fn):
    @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
        padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name)
        if not stride:
            stride = kernel_size
        if count_include_pad:
            input = g.op("Pad", input,
                         g.op("Constant", value_t=torch.tensor(((0,) * 2 + padding) * 2)), mode_s='constant')
            padding = (0,) * len(padding)
        output = g.op("AveragePool", input,
                      kernel_shape_i=tuple_fn(kernel_size),
                      strides_i=tuple_fn(stride),
                      pads_i=padding * 2,
                      ceil_mode_i=ceil_mode)
        return output
    return symbolic_fn

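# Note on count_include_pad: the padding is materialized with an explicit
# onnx::Pad (constant zeros) and AveragePool then runs with all-zero pads, so
# padded entries are counted in every window's divisor without relying on
# AveragePool's own count_include_pad attribute.
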
avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)

@parse_args('v', 'i', 'i', 'i', 'i')
def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
    u, indices, inverse_indices, counts = g.op("Unique", self, axis_i=dim, sorted_i=sorted, outputs=4)
    return u, inverse_indices, counts

@parse_args('v', 'v', 'i', 'i', 'i', 'none')
def topk(g, self, k, dim, largest, sorted, out=None):
    return sym_help._topk_helper(g, self, k, dim, largest=largest, sorted=sorted, out=out)

@parse_args('v', 'i', 'i', 'none')
def sort(g, self, dim, decending, out=None):
    # NB: "decending" spelling matches the keyword expected by _sort_helper
    return sym_help._sort_helper(g, self, dim, decending=decending, out=out)

def round(g, self):
    return g.op("Round", self)

@parse_args('v', 'v', 'i')
def split_with_sizes(g, self, split_sizes, dim):
    if sym_help._is_value(split_sizes) and split_sizes.node().kind() == 'prim::ListConstruct':
        return g.op("SplitToSequence", self, split_sizes, axis_i=dim)
    else:
        return torch.onnx.symbolic_opset9.split_with_sizes(g, self, split_sizes, dim)

# Generate paddings in ONNX order based on pad in pytorch.
# Arguments:
#     dim: the dimension of the tensor.
#     pad: the paddings in pytorch.
#          The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
#          where m is in range [0, n].
def _prepare_onnx_paddings(g, dim, pad):
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
    pad_len = torch.onnx.symbolic_opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
    # Set extension = [0] * (dim * 2 - len(pad))
    extension = g.op("Sub", g.op("Mul", g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64)),
                                 g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), pad_len)
    # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    paddings = g.op("Concat", pad, g.op("ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)), axis_i=0)
    # Reshape, reverse order, and collate first the beginnings and then the ends:
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #             [..., 0, dim_n-1_end, dim_n_end]]
    # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end]
    paddings = g.op("Reshape", paddings, g.op("Constant", value_t=torch.tensor([-1, 2])))
    paddings = g.op("Transpose", torch.onnx.symbolic_opset10.flip(g, paddings, [0]), perm_i=[1, 0])
    paddings = g.op("Reshape", paddings, g.op("Constant", value_t=torch.tensor([-1])))
    padding_c = g.op("Cast", paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    return padding_c

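# Worked example (hypothetical values): for a 4-D input and
# pad = [1, 2, 3, 4] (PyTorch order: last-dim begin/end, then second-to-last),
# the graph above computes
#   padded           -> [1, 2, 3, 4, 0, 0, 0, 0]
#   reshaped         -> [[1, 2], [3, 4], [0, 0], [0, 0]]
#   flip + transpose -> [[0, 0, 3, 1], [0, 0, 4, 2]]
#   flattened        -> [0, 0, 3, 1, 0, 0, 4, 2]
# which is the ONNX order x1_begin, x2_begin, ..., x1_end, x2_end, ...
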
def constant_pad_nd(g, input, padding, value=None):
    mode = "constant"
    value = sym_help._maybe_get_scalar(value)
    value = sym_help._if_scalar_type_as(g, value, input)
    pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
    return g.op("Pad", input, pad, value, mode_s=mode)

def reflection_pad(g, input, padding):
    mode = "reflect"
    paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
    return g.op("Pad", input, paddings, mode_s=mode)

def replication_pad(g, input, padding):
    mode = "edge"
    paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
    return g.op("Pad", input, paddings, mode_s=mode)

reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
replication_pad1d = replication_pad
replication_pad2d = replication_pad
replication_pad3d = replication_pad

def det(g, self):
    return g.op("Det", self)

def logdet(g, input):
    from torch.onnx.symbolic_opset9 import log
    return log(g, det(g, input))

def arange(g, *args):
    def _get_arange_dtype(dtype):
        dtype = sym_help._maybe_get_const(dtype, 'i')
        return dtype

    if len(args) == 5:
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[1])
        type, end, start, step = sym_help._arange_cast_helper(g, end=args[0], dtype=dtype)
        start_default = g.op("Constant", value_t=torch.tensor(0, dtype=sym_help.scalar_type_to_pytorch_type[type]))
        delta_default = g.op("Constant", value_t=torch.tensor(1, dtype=sym_help.scalar_type_to_pytorch_type[type]))
        arange_tensor = g.op("Range", start_default, end, delta_default)
    elif len(args) == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        type, end, start, step = sym_help._arange_cast_helper(g, start=args[0], end=args[1], dtype=dtype)
        delta_default = g.op("Constant", value_t=torch.tensor(1, dtype=sym_help.scalar_type_to_pytorch_type[type]))
        arange_tensor = g.op("Range", start, end, delta_default)
    elif len(args) == 7:
        # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[3])
        type, end, start, step = sym_help._arange_cast_helper(g, start=args[0], end=args[1], step=args[2], dtype=dtype)
        arange_tensor = g.op("Range", start, end, step)
    else:
        raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.")
    return arange_tensor

@parse_args('v', 'i')
def _dim_arange(g, like, dim):
    like_shape = g.op('Shape', like)
    stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("_caffe2::Range", stop)
    return arange(g, stop, 4, None, None, None)

def size(g, self, dim=None):
    if dim is None:
        return g.op("Shape", self)
    return sym_help._size_helper(g, self, dim)

def squeeze(g, self, dim=None):
    if dim is None:
        dims = []
        for i, size in enumerate(self.type().sizes()):
            if size == 1:
                dims.append(i)
    else:
        dims = [sym_help._get_const(dim, 'i', 'dim')]
    return g.op("Squeeze", self, axes_i=dims)

@parse_args('v', 'i')
def unsqueeze(g, self, dim):
    return g.op("Unsqueeze", self, axes_i=[dim])

def mm(g, self, other):
    return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)

def index_fill(g, self, dim, index, value):
    dim_value = sym_help._parse_arg(dim, 'i')
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill")
    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    value = sym_help._maybe_get_scalar(value)
    value = sym_help._if_scalar_type_as(g, value, self)
    expanded_value = expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)

def index_copy(g, self, dim, index, source):
    dim_value = sym_help._parse_arg(dim, 'i')
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, source, dim_i=dim_value, operator_s="index_copy")
    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, expanded_index, source)

def __rshift_(g, self, other):
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if other.type().scalarType() != self.type().scalarType():
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()])

    if self.type().scalarType() == 'Byte':
        return g.op('BitShift', self, other, direction_s="RIGHT")

    two = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float'])
    two_pow = g.op('Pow', two, other)

    rshift = g.op('Div', self, two_pow)
    return rshift

def __lshift_(g, self, other):
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if other.type().scalarType() != self.type().scalarType():
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()])

    if self.type().scalarType() == 'Byte':
        return g.op('BitShift', self, other, direction_s="LEFT")

    two = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float'])
    two_pow = g.op('Pow', two, other)

    lshift = g.op('Mul', self, two_pow)
    return lshift

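# Sketch of the shift lowerings above: opset 11 BitShift is only defined for
# unsigned integer types, so 'Byte' (uint8) uses it directly, while other types
# fall back to multiplying (left shift) or dividing (right shift) by 2**other,
# with the power computed via onnx::Pow in float, e.g. 5 << 2 -> 5 * 2.0**2 = 20.
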
def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d, padding_d, stride_d):
    # Input is always 4-D (N, C, H, W)
    # Calculate indices of sliding blocks along spatial dimension
    # Slide the kernel over the input along each dim d:
    # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1)
    # with steps = stride

    blocks_d = g.op("Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)))
    blocks_d = g.op("Sub", blocks_d, g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))))

    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = g.op("Range", g.op("Constant", value_t=torch.tensor(0)),
                            blocks_d, g.op("Constant", value_t=torch.tensor(stride_d)))

    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = numpy.arange(0, kernel_size_d * dilation_d, dilation_d)
    kernel_grid = g.op("Constant", value_t=torch.tensor([kernel_grid]))

    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    blocks_d_indices = g.op('Unsqueeze', blocks_d_indices, axes_i=[0])  # Reshape to [1, -1]
    kernel_mask = g.op('Reshape', kernel_grid, g.op('Constant', value_t=torch.tensor([-1, 1])))
    block_mask = g.op("Add", blocks_d_indices, kernel_mask)

    return block_mask

def _get_im2col_padded_input(g, input, padding_h, padding_w):
    # Input is always 4-D tensor (N, C, H, W)
    # Padding tensor has the following format: (padding_h, padding_w)
    # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
    pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
    return g.op("Pad", input, pad)

def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
    batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    channel_unfolded = g.op("Mul", channel_dim,
                            g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)))

    return g.op("Concat",
                g.op("Unsqueeze", batch_dim, axes_i=[0]),
                g.op("Unsqueeze", channel_unfolded, axes_i=[0]),
                g.op("Constant", value_t=torch.tensor([-1])), axis_i=0)

@parse_args('v', 'is', 'is', 'is', 'is')
def im2col(g, input, kernel_size, dilation, padding, stride):
    # Input is always 4-D tensor (N, C, H, W)
    # All other args are int[2]

    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))

    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]

    blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h, dilation_h, padding_h, stride_h)
    blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w, dilation_w, padding_w, stride_w)

    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)

    # For a 4-D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6.]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get:
    # [[[1., 2., 4., 5.],
    #   [2., 3., 5., 6.],
    #   [4., 5., 7., 8.],
    #   [5., 6., 8., 9.]]]
    output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
    output = g.op("Gather", output, blocks_col_indices, axis_i=4)
    output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
    return g.op("Reshape", output, output_shape)

@parse_args('v', 'i', 'i')
def flatten(g, input, start_dim, end_dim):
    dim = input.type().dim()
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim == 1:
        if (end_dim == -1 or (end_dim is not None and end_dim == dim - 1)):
            return g.op("Flatten", input, axis_i=start_dim)
    elif start_dim == 0:
        if (end_dim == -2 or (end_dim is not None and end_dim == dim - 2)):
            return g.op("Flatten", input, axis_i=end_dim + 1)
    # use Reshape for cases where the output shape is not 2D
    if not input.isCompleteTensor():
        return _unimplemented("flatten",
                              "input size not accessible "
                              "(consider using reshape op instead of flatten op to export to ONNX)")
    # if end_dim is negative add dim
    if end_dim < 0:
        end_dim = dim + end_dim
    input_dims = input.type().sizes()
    output_dims = []
    for i in range(0, dim):
        if start_dim < i and end_dim >= i:
            output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
        else:
            output_dims.append(input_dims[i])
    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
    p = _reshape_from_tensor(g, input, shape)
    return p

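# Reshape-fallback sketch (hypothetical shapes): flattening an input of shape
# (2, 3, 4, 5) with start_dim=1, end_dim=2 folds dims 1..2 together, giving
# output_dims = [2, 12, 5] and a constant-shape Reshape, whereas
# start_dim=1, end_dim=-1 takes the early onnx::Flatten path instead.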