mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Co-authored-by: David Fan <jiafamicrosoft.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/72986
361 lines
17 KiB
Python
361 lines
17 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
import torch
|
||
from torch.nn.modules.utils import _single, _pair, _triple
|
||
import torch.onnx
|
||
# This import monkey-patches graph manipulation methods on Graph, used for the
|
||
# ONNX symbolics
|
||
import torch.onnx.utils
|
||
|
||
import torch.onnx.symbolic_helper as sym_help
|
||
from torch.onnx.symbolic_helper import parse_args, _unimplemented
|
||
import torch.onnx.symbolic_opset9
|
||
from torch.onnx.symbolic_opset9 import linear
|
||
|
||
from sys import maxsize
|
||
|
||
# EDITING THIS FILE? READ THIS FIRST!
|
||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||
|
||
# This file exports ONNX ops for opset 10
|
||
# Opset 10 is supported by ONNX release 1.5.0
|
||
# released on 04/24/19
|
||
|
||
|
||
def div(g, self, other, *args):
    """Symbolic for division, with or without a rounding mode.

    Any extra positional arguments are forwarded to ``_div_rounding_mode``;
    when there are none, the opset 9 ``true_divide`` symbolic is used.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return torch.onnx.symbolic_opset9.true_divide(g, self, other)
|
||
|
||
|
||
@parse_args("v", "v", "s")
def _div_rounding_mode(g, self, other, rounding_mode):
    """Dispatch a division on its ``rounding_mode`` string.

    ``"floor"`` is handled by this opset's ``_floor_divide``; every other
    value is delegated to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return torch.onnx.symbolic_opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
|
||
|
||
|
||
def _floor_divide(g, self, other):
    """Emit ONNX nodes computing the floor of ``self / other``.

    If either operand is floating point (per ``sym_help._is_fp``), the opset 9
    ``true_divide`` result is passed through ``Floor``.  Otherwise an integer
    ``Div`` is emitted and corrected through ``Where`` so that the result is
    rounded down instead of truncated.
    """
    if sym_help._is_fp(self) or sym_help._is_fp(other):
        quotient = torch.onnx.symbolic_opset9.true_divide(g, self, other)
        return g.op("Floor", quotient)

    # Integer division does truncation rounding
    truncated = g.op("Div", self, other)

    # Division is negative if: self < 0 != other < 0
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_is_negative = g.op("Less", self, zero)
    other_is_negative = g.op("Less", other, zero)
    signs_differ = g.op("Xor", self_is_negative, other_is_negative)

    # For negative numbers with self % other != 0, subtract 1 to round down instead of up
    remainder = g.op("Mod", self, other, fmod_i=0)
    divides_evenly = g.op("Equal", remainder, zero)
    has_remainder = g.op("Not", divides_evenly)
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    rounded_down = g.op("Sub", truncated, one)
    return g.op("Where", needs_fixup, rounded_down, truncated)
|
||
|
||
|
||
@parse_args("v", "i", "i", "none")
def sort(g, self, dim, decending, out=None):
    """Symbolic for ``sort``; defers entirely to ``sym_help._sort_helper``."""
    helper_kwargs = {"decending": decending, "out": out}
    return sym_help._sort_helper(g, self, dim, **helper_kwargs)
|
||
|
||
|
||
@parse_args("v", "v", "i", "i", "i", "none")
def topk(g, self, k, dim, largest, sorted, out=None):
    """Symbolic for ``topk``; defers entirely to ``sym_help._topk_helper``."""
    helper_kwargs = {"largest": largest, "sorted": sorted, "out": out}
    return sym_help._topk_helper(g, self, k, dim, **helper_kwargs)
|
||
|
||
|
||
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build the symbolic function for an ``ndims``-dimensional max pooling op.

    ``tuple_fn`` is applied to every size-like argument before it is used as an
    ONNX attribute.  When ``return_indices`` is true the generated symbolic
    returns ``(output, indices)``; otherwise it returns only the pooled output.
    ``name`` is accepted but not used inside this factory.
    """
    @parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        # An empty stride falls back to the kernel size.
        stride = stride or kernel_size
        attrs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": tuple_fn(padding) * 2,
            "strides_i": tuple_fn(stride),
            "ceil_mode_i": ceil_mode,
        }
        # The dilations attribute is only emitted when some dilation differs from 1.
        if set(tuple_fn(dilation)) != {1}:
            attrs["dilations_i"] = tuple_fn(dilation)

        if not return_indices:
            return g.op("MaxPool", input, outputs=1, **attrs)

        # Easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index will have its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information :
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        pooled, indices = g.op("MaxPool", input, outputs=2, **attrs)
        _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                    kernel_shape_i=[1] * ndims,
                                    strides_i=[1] * ndims)
        # convert indices to have non-flattened indices values
        from torch.onnx.symbolic_opset9 import sub
        spatial_axes = list(range(2, 2 + ndims))
        first_index = sym_help._slice_helper(g, flattened_indices, axes=spatial_axes,
                                             starts=tuple_fn(0), ends=tuple_fn(1))
        indices = sub(g, indices, first_index)
        return pooled, indices

    return symbolic_fn
|
||
|
||
|
||
# Max pooling symbolics that return only the pooled tensor.
max_pool1d = _max_pool("max_pool1d", _single, 1, False)
max_pool2d = _max_pool("max_pool2d", _pair, 2, False)
max_pool3d = _max_pool("max_pool3d", _triple, 3, False)
# Max pooling symbolics that additionally return the indices.
max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, True)
max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, True)
max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, True)
|
||
|
||
|
||
def _avg_pool(name, tuple_fn):
    """Build the symbolic function for an average pooling op.

    ``tuple_fn`` is applied to the kernel size and stride before they are used
    as ONNX attributes; ``name`` is forwarded to ``sym_help._avgpool_helper``.
    """
    @parse_args("v", "is", "is", "is", "i", "i", "none")
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
        # An empty stride falls back to the kernel size.
        stride = stride or kernel_size
        padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name)
        if count_include_pad:
            # Pad explicitly with zeros, then pool with no implicit padding.
            explicit_pads = ((0,) * 2 + padding) * 2
            input = g.op("Pad", input,
                         pads_i=explicit_pads,
                         mode_s="constant",
                         value_f=0.)
            padding = (0,) * len(padding)
        return g.op("AveragePool", input,
                    kernel_shape_i=tuple_fn(kernel_size),
                    strides_i=tuple_fn(stride),
                    pads_i=padding * 2,
                    ceil_mode_i=ceil_mode)

    return symbolic_fn
|
||
|
||
|
||
# Average pooling symbolics for 1, 2 and 3 spatial dimensions.
avg_pool1d = _avg_pool(name="avg_pool1d", tuple_fn=_single)
avg_pool2d = _avg_pool(name="avg_pool2d", tuple_fn=_pair)
avg_pool3d = _avg_pool(name="avg_pool3d", tuple_fn=_triple)
|
||
|
||
|
||
def _interpolate(name, dim, interpolate_mode):
    """Build the symbolic function for an ``upsample_*`` op.

    The generated symbolic emits an ONNX ``Resize`` node in ``interpolate_mode``.
    ``align_corners == True`` is reported as unimplemented under ``name``.  When
    no scales are supplied they are derived from ``output_size`` and ``dim`` via
    ``sym_help._interpolate_size_to_scales``.
    """
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
        sym_help._interpolate_warning(interpolate_mode)
        if sym_help._maybe_get_scalar(align_corners):
            return _unimplemented(name, "align_corners == True")
        if scales is None:
            scales = sym_help._interpolate_size_to_scales(g, input, output_size, dim)
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
|
||
|
||
|
||
# Nearest-neighbour upsampling symbolics.
upsample_nearest1d = _interpolate(name="upsample_nearest1d", dim=3, interpolate_mode="nearest")
upsample_nearest2d = _interpolate(name="upsample_nearest2d", dim=4, interpolate_mode="nearest")
upsample_nearest3d = _interpolate(name="upsample_nearest3d", dim=5, interpolate_mode="nearest")
# Linear / bilinear / trilinear upsampling all use the "linear" Resize mode.
upsample_linear1d = _interpolate(name="upsample_linear1d", dim=3, interpolate_mode="linear")
upsample_bilinear2d = _interpolate(name="upsample_bilinear2d", dim=4, interpolate_mode="linear")
upsample_trilinear3d = _interpolate(name="upsample_trilinear3d", dim=5, interpolate_mode="linear")
|
||
|
||
def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias):
    """Symbolic for the generic interpolate op.

    Scales and the Resize mode are resolved by
    ``sym_help._interpolate_get_scales_and_mode``.  ``recompute_scale_factor``
    and ``antialias`` are accepted but not used here.
    """
    resolved = sym_help._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners)
    scales, mode = resolved
    return g.op("Resize", input, scales, mode_s=mode)
|
||
|
||
|
||
def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
    """Emit an ONNX ``Slice`` node taking starts/ends/axes/steps as inputs.

    With ``dynamic_slice`` the ``starts`` and ``ends`` arguments are graph
    values and are unsqueezed on axis 0; an integer ``axes`` is turned into an
    unsqueezed constant.  Otherwise the arguments are Python sequences that
    are embedded as constants.  A static slice covering ``[0, INT64_MAX)`` on a
    single axis with unit (or no) step returns ``input`` unchanged.
    """
    int64_max = 9223372036854775807

    if not dynamic_slice:
        assert len(starts) == len(ends)
        assert len(starts) == len(axes)
        assert steps is None or len(starts) == len(steps)
        covers_whole_axis = len(starts) == 1 and starts[0] == 0 and ends[0] == int64_max
        if covers_whole_axis and (steps is None or (len(steps) == 1 and steps[0] == 1)):
            # Nothing would be cut away, so no node is needed.
            return input
        axes = g.op("Constant", value_t=torch.tensor(axes))
        starts = g.op("Constant", value_t=torch.tensor(starts))
        ends = g.op("Constant", value_t=torch.tensor(ends))
    else:
        starts = sym_help._unsqueeze_helper(g, starts, [0])
        ends = sym_help._unsqueeze_helper(g, ends, [0])
        if isinstance(axes, int):
            axes = g.op("Constant", value_t=torch.tensor(axes))
        axes = sym_help._unsqueeze_helper(g, axes, [0])

    slice_inputs = [input, starts, ends, axes]
    if steps is not None:
        slice_inputs.append(g.op("Constant", value_t=torch.tensor(steps)))
    return g.op("Slice", *slice_inputs)
|
||
|
||
|
||
def slice(g, self, *args):
    """Symbolic for ``aten::slice`` (tensor and list overloads).

    The slice is treated as dynamic when ``start``, ``end`` or ``dim`` is
    neither a ``None`` constant nor an ``onnx::Constant``.  In that case the
    graph values are forwarded as-is, with ``None`` replaced by constant
    defaults.  Otherwise the arguments are parsed into Python ints.

    Raises:
        NotImplementedError: if the number of extra arguments is neither 3 nor 4.
    """
    int64_max = 9223372036854775807

    if len(args) == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        start, end, step = args
        dim = 0
    elif len(args) == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dim, start, end, step = args
    else:
        raise NotImplementedError("Unknown aten::slice signature")

    def is_none_constant(value):
        return value.node().kind() == "prim::Constant" and value.type().kind() == "NoneType"

    def is_onnx_constant(value):
        return value.node().kind() == "onnx::Constant"

    start_is_none = is_none_constant(start)
    end_is_none = is_none_constant(end)
    start_is_onnx_const = is_onnx_constant(start)
    end_is_onnx_const = is_onnx_constant(end)
    step = sym_help._parse_arg(step, "i")

    dynamic_slice = (
        not (start_is_none or start_is_onnx_const)
        or not (isinstance(end, int) or end_is_none or end_is_onnx_const)
        or (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant")
    )

    if dynamic_slice:
        if start_is_none:
            start = g.op("Constant", value_t=torch.tensor(0))
        if end_is_none:
            end = g.op("Constant", value_t=torch.tensor(int64_max))
    else:
        start = [0 if start_is_none else sym_help._parse_arg(start, "i")]
        end = [int64_max if end_is_none else sym_help._parse_arg(end, "i")]
        dim = [sym_help._parse_arg(dim, "i")]

    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
|
||
|
||
|
||
@parse_args("v", "is")
def flip(g, input, dims):
    """Symbolic for ``flip``: a slice with step -1 along every axis in ``dims``."""
    axis_count = len(dims)
    # Start at the last element and walk backwards past the first one.
    reversed_starts = [-1] * axis_count
    reversed_ends = [-9223372036854775807] * axis_count
    reversed_steps = [-1] * axis_count
    return sym_help._slice_helper(g, input, axes=dims,
                                  starts=reversed_starts,
                                  ends=reversed_ends,
                                  steps=reversed_steps)
|
||
|
||
|
||
def fmod(g, input, other):
    """Symbolic for ``fmod``: an ONNX ``Mod`` node with the ``fmod`` attribute set to 1."""
    mod_node = g.op("Mod", input, other, fmod_i=1)
    return mod_node
|
||
|
||
|
||
@parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse,
                  per_sample_weights,
                  include_last_offset,
                  padding_idx):
    """Symbolic for ``embedding_bag`` with a statically known number of offsets.

    One ``Slice``/``Gather``/reduce chain is emitted per bag and the per-bag
    results are concatenated along axis 0.  ``mode`` 0 reduces with a sum,
    1 with a mean, and any other value with a max.

    Returns:
        ``(output, None, None, None)``, or the result of
        ``sym_help._onnx_unsupported`` when the export is not supported
        (``scale_grad_by_freq`` in training mode, or unknown size of
        ``offsets`` along dimension 0).

    Raises:
        RuntimeError: if ``padding_idx`` is a non-negative integer.
    """
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported("embedding_bag with scale_grad_by_freq for training mode")
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")
    from torch.onnx.symbolic_opset9 import select
    import warnings
    # The message used to end with a stray apostrophe; that typo is removed here.
    warnings.warn("Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
                  "Please use opset 11 or higher to export model for dynamic input shape.")

    offsets_dim_0 = sym_help._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return sym_help._onnx_unsupported("embedding_bag with unknown shape of offsets for opset 10 is not supported. "
                                          "please use opset 11 or higher.")

    if include_last_offset:
        # The final offset is already present and closes the last bag.
        offset_len = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        # Append a sentinel end offset so the last bag runs to the end of ``indices``.
        offset_len = offsets_dim_0
        offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))]
        offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)

    bag_outputs = []
    for i in range(offset_len):
        start_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), [0])
        end_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), [0])
        axes_ = g.op("Constant", value_t=torch.tensor([0]))
        indices_row = g.op("Slice", indices, start_, end_, axes_)

        embeddings = g.op("Gather", embedding_matrix, indices_row)
        if not sym_help._is_none(per_sample_weights):
            per_sample_weights_row = g.op("Slice", per_sample_weights, start_, end_, axes_)
            per_sample_weights_row = sym_help._unsqueeze_helper(g, per_sample_weights_row, [1])
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)

        if mode == 0:
            embeddings = sym_help._reducesum_helper(g, embeddings, axes_i=[0], keepdims_i=0)
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        embeddings = sym_help._unsqueeze_helper(g, embeddings, [0])
        bag_outputs.append(embeddings)

    output = g.op("Concat", *bag_outputs, axis_i=0)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return output, None, None, None
|
||
|
||
|
||
@parse_args("v", "t", "i", "i", "i")
def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
    """Symbolic for ``fake_quantize_per_tensor_affine``.

    Emits ``QuantizeLinear`` followed by ``DequantizeLinear`` using the same
    scale and zero point.  The zero point is int8 when ``quant_min`` is -128
    and uint8 otherwise.

    Raises:
        RuntimeError: if ``(quant_min, quant_max)`` is neither ``(0, 255)`` nor
            ``(-128, 127)``.
    """
    # Validate the bounds as a pair: checking them independently would let
    # mismatched ranges such as (-128, 255) or (0, 127) through, and those would
    # then be exported as a full 8-bit range that does not match the request.
    if (quant_min, quant_max) not in ((0, 255), (-128, 127)):
        raise RuntimeError(
            "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
    scale = scale.float().data  # Avoid exporter generating double type
    zero_point_dtype = torch.int8 if quant_min == -128 else torch.uint8
    zero_point = torch.tensor(zero_point, dtype=zero_point_dtype)  # ONNX requires zero_point to be tensor
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    return g.op("DequantizeLinear", quantized, scale, zero_point)
|
||
|
||
|
||
def isinf(g, input):
    """Symbolic for ``isinf``: ``IsInf`` applied to the input cast via opset 9 ``_cast_Double``."""
    from torch.onnx.symbolic_opset9 import _cast_Double  # type: ignore[attr-defined]
    as_double = _cast_Double(g, input, False)
    return g.op("IsInf", as_double)
|
||
|
||
|
||
def isfinite(g, input):
    """Symbolic for ``isfinite``: the negation of ``isinf(input) or isnan(input)``."""
    from torch.onnx.symbolic_opset9 import isnan, __not_, __or_
    is_infinite = isinf(g, input)
    is_nan = isnan(g, input)
    not_finite = __or_(g, is_infinite, is_nan)
    return __not_(g, not_finite)
|
||
|
||
|
||
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
|
||
class Quantized:
    """Symbolic functions for ops registered under the ``quantized`` domain."""

    domain = "quantized"

    # DequantizeLinear was added in opset version 10.
    @staticmethod
    def linear(g, input_original, weight, weight_scale, weight_zero_point, bias, op_scale, op_zero_point):
        """Export a quantized linear op as dequantize -> ``linear`` -> quantize.

        ``input_original`` is unpacked with ``sym_help._unpack_tuple`` into
        ``(value, scale, zero_point)``.  The input is dequantized as UINT8 and
        the weight as INT8.  The result is re-quantized with ``op_scale`` and
        ``op_zero_point``, falling back to the input's scale and zero point
        when those are ``None``.

        Returns:
            A ``prim::TupleConstruct`` of ``(output, op_scale, op_zero_point)``.
        """
        input_value, input_scale, input_zero_point = sym_help._unpack_tuple(input_original)
        # From https://pytorch.org/docs/master/generated/torch.nn.quantized.functional.linear.html
        # input (Tensor) – Quantized input of type torch.quint8
        input_type_dq = torch.onnx.TensorProtoDataType.UINT8
        input_value = g.op("Cast", input_value, to_i=input_type_dq)
        input_scale = g.op("Cast", input_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        input_zero_point = g.op("Cast", input_zero_point, to_i=input_type_dq)
        input = g.op("DequantizeLinear", input_value, input_scale, input_zero_point)
        # weight (Tensor) – Quantized weight of type torch.qint8
        weight_type_dq = torch.onnx.TensorProtoDataType.INT8
        weight = g.op("Cast", weight, to_i=weight_type_dq)
        weight_scale = g.op("Cast", weight_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        weight_zero_point = g.op("Cast", weight_zero_point, to_i=weight_type_dq)
        weight = g.op("DequantizeLinear", weight, weight_scale, weight_zero_point)
        # bias (Tensor) – None or fp32 bias of type torch.float
        # A None bias has nothing to cast, so it is passed on unchanged.
        # NOTE(review): assumes the opset 9 ``linear`` symbolic handles a None bias — confirm.
        if not sym_help._is_none(bias):
            bias = g.op("Cast", bias, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        output = linear(g, input, weight, bias)

        if op_scale is None:
            op_scale = input_scale
        elif op_scale.type().scalarType() != "Float":
            op_scale = g.op("Cast", op_scale, to_i=sym_help.cast_pytorch_to_onnx["Float"])

        if op_zero_point is None:
            op_zero_point = input_zero_point
        elif op_zero_point.type().scalarType() != "Byte":
            op_zero_point = g.op("Cast", op_zero_point, to_i=sym_help.cast_pytorch_to_onnx["Byte"])
        output = g.op("QuantizeLinear", output, op_scale, op_zero_point)
        return g.op("prim::TupleConstruct", output, op_scale, op_zero_point)
|