https://github.com/pytorch/pytorch/pull/106270 solved the [`ceil_mode` corner case](https://github.com/onnx/onnx/issues/5711) by using `get_pool_ceil_padding`. However, computing the ceil padding on the converter side only works when the input shapes are already known, so that change regressed exports with dynamic input shapes. This PR (1) refactors the code to follow the torchlib implementation, (2) adds a dynamic-shapes test, and (3) disables the corner-case tests with comments saying to re-enable them once the [real fix in ONNX](https://github.com/onnx/onnx/pull/5741) is merged.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113318
Approved by: https://github.com/thiagocrepaldi
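A minimal sketch of the kind of export this affects (illustrative only; the module, input sizes, file name, and axis names are placeholders, not taken from the PR):

    import torch

    model = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
    x = torch.randn(1, 3, 32, 32)
    torch.onnx.export(
        model,
        x,
        "max_pool_ceil.onnx",
        opset_version=10,
        input_names=["x"],
        dynamic_axes={"x": {0: "batch", 2: "height", 3: "width"}},
    )

With statically known shapes the old `get_pool_ceil_padding` workaround could pre-compute the extra padding; with `dynamic_axes` it could not, which is the regression described above.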
from __future__ import annotations

import functools
import sys
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 10
# Opset 10 is supported by ONNX release 1.5.0
# release on 04/24/19


__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv3d_relu",
    "quantized_conv1d",
    "quantized_conv2d",
    "quantized_conv3d",
    "quantized_conv_transpose1d",
    "quantized_conv_transpose2d",
    "quantized_conv_transpose3d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_linear_relu",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply
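# Example (illustrative): `_apply_params` is what lets a single factory such as
# `_max_pool` below be registered for several overloads at once, e.g.
#
#     @_onnx_symbolic(
#         "aten::max_pool2d",
#         decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
#     )
#     def _max_pool(name, expand_size, return_indices): ...
#
# registers the result of `_max_pool("max_pool2d", 2, return_indices=False)`
# as the symbolic function for `aten::max_pool2d`.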
@_onnx_symbolic("aten::div")
|
||
@_beartype.beartype
|
||
def div(g: jit_utils.GraphContext, self, other, *args):
|
||
if len(args) == 0:
|
||
return opset9.true_divide(g, self, other)
|
||
else:
|
||
return _div_rounding_mode(g, self, other, *args)
|
||
|
||
|
||
@symbolic_helper.parse_args("v", "v", "s")
|
||
@_beartype.beartype
|
||
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
|
||
if rounding_mode == "floor":
|
||
return _floor_divide(g, self, other)
|
||
else:
|
||
return opset9._div_rounding_mode(g, self, other, rounding_mode)
|
||
|
||
|
||
@_onnx_symbolic("aten::_floor_divide")
|
||
@_beartype.beartype
|
||
def _floor_divide(g: jit_utils.GraphContext, self, other):
|
||
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
|
||
out = opset9.true_divide(g, self, other)
|
||
return g.op("Floor", out)
|
||
else:
|
||
# Integer division does trunction rounding
|
||
div = g.op("Div", self, other)
|
||
# Division is negative if: self < 0 != other < 0
|
||
zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
|
||
negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero))
|
||
|
||
# For negative numbers with self % other != 0, subtract 1 to round down instead of up
|
||
mod = g.op("Mod", self, other, fmod_i=0)
|
||
fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))
|
||
|
||
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
|
||
fixup = g.op("Sub", div, one)
|
||
return g.op("Where", fixup_mask, fixup, div)
|
||
|
||
|
||
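# Worked example (illustrative): for self = -7 and other = 2 the ONNX "Div"
# above truncates toward zero and yields -3, whereas floor division should
# give -4. The signs differ and Mod(-7, 2) != 0, so fixup_mask is true and the
# "Where" selects -3 - 1 = -4, matching torch.div(..., rounding_mode="floor").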
@_onnx_symbolic("aten::sort")
|
||
@symbolic_helper.parse_args("v", "i", "i", "none")
|
||
@_beartype.beartype
|
||
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
|
||
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
|
||
|
||
|
||
@_onnx_symbolic("aten::topk")
|
||
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
|
||
@_beartype.beartype
|
||
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
|
||
return symbolic_helper._topk_helper(
|
||
g, self, k, dim, largest=largest, sorted=sorted, out=out
|
||
)
|
||
|
||
|
||
def _aten_max_pool_onnx(
|
||
g: jit_utils.GraphContext,
|
||
self: _C.Value,
|
||
kernel_shape: Sequence[int],
|
||
strides: Sequence[int],
|
||
pads: Sequence[int],
|
||
dilations: Sequence[int],
|
||
ceil_mode: bool,
|
||
unbatched_rank: int,
|
||
) -> _C.Value:
|
||
self_rank = g.op("Size", g.op("Shape", self))
|
||
if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1
|
||
self = g.op(
|
||
"Unsqueeze",
|
||
self,
|
||
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
|
||
)
|
||
|
||
pool_result, _ = g.op(
|
||
"MaxPool",
|
||
self,
|
||
outputs=2,
|
||
ceil_mode_i=ceil_mode,
|
||
dilations_i=dilations,
|
||
kernel_shape_i=kernel_shape,
|
||
pads_i=pads,
|
||
strides_i=strides,
|
||
)
|
||
|
||
if self_rank == unbatched_rank:
|
||
pool_result = g.op(
|
||
"Squeeze",
|
||
pool_result,
|
||
g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
|
||
)
|
||
|
||
return pool_result
|
||
|
||
|
||
# For MaxPool
|
||
def _adjust_attributes_of_max_pool(
|
||
expand_size: int,
|
||
kernel_size: Union[Sequence[int], int],
|
||
stride: Union[Sequence[int], int],
|
||
padding: Union[Sequence[int], int],
|
||
dilation: Union[Sequence[int], int],
|
||
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
|
||
"""Adjust attributes of avg_pool to match ONNX specification."""
|
||
|
||
if isinstance(dilation, int):
|
||
dilation = [dilation] * expand_size
|
||
|
||
if isinstance(kernel_size, int):
|
||
kernel_shape = [kernel_size] * expand_size
|
||
else:
|
||
kernel_shape = kernel_size # type: ignore[assignment]
|
||
|
||
if isinstance(padding, int):
|
||
pads = [padding] * expand_size * 2 # type: ignore[operator, assignment]
|
||
elif len(padding) == 1:
|
||
pads = padding * expand_size * 2 # type: ignore[operator, assignment]
|
||
elif len(padding) == 2:
|
||
# 2D padding
|
||
pads = padding * 2 # type: ignore[operator, assignment]
|
||
elif len(padding) == 3:
|
||
# 3D padding
|
||
pads = padding * 2 # type: ignore[operator, assignment]
|
||
else:
|
||
# When padding is already done for all dimensions,
|
||
# we don't need to double it
|
||
# eg: (1, 1, 1, 1, 1, 1)
|
||
pads = padding # type: ignore[assignment]
|
||
|
||
if isinstance(stride, int):
|
||
strides = [stride] * expand_size
|
||
elif not stride:
|
||
strides = kernel_shape
|
||
else:
|
||
strides = stride # type: ignore[assignment]
|
||
|
||
return (kernel_shape, strides, pads, dilation)
|
||
|
||
|
||
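# Example (illustrative): for a 2-D pool with kernel_size=3, stride=[] (empty,
# i.e. default to the kernel size), padding=1 and dilation=1, the helper above
# returns kernel_shape=[3, 3], strides=[3, 3], pads=[1, 1, 1, 1] and
# dilation=[1, 1]: scalar arguments are broadcast per spatial dim and the
# padding is doubled into ONNX's begin/end layout.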
def _aten_max_pool_with_indices_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
    n_dims_one: Sequence[int],
    n_dims_zero: Sequence[int],
    n_dims_axes: Sequence[int],
) -> Tuple[_C.Value, Sequence[int]]:
    self_rank = g.op("Size", g.op("Shape", self))
    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
        self = g.op(
            "Unsqueeze",
            self,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )

    pool_result, indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )
    _, flatten_indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        dilations_i=dilations,
        kernel_shape_i=n_dims_one,
        strides_i=n_dims_one,
    )

    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))

    delta = g.op("Slice", flatten_indices, starts, ends, axes)
    indices = g.op("Sub", indices, delta)

    if self_rank == unbatched_rank:
        pool_result = g.op(
            "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64)
        )
        indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64))

    return (pool_result, indices)


@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[_apply_params("max_pool1d", 1, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[_apply_params("max_pool3d", 3, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        _apply_params(
            "max_pool1d_with_indices",
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        _apply_params(
            "max_pool2d_with_indices",
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        _apply_params(
            "max_pool3d_with_indices",
            3,
            return_indices=True,
        )
    ],
)
@_beartype.beartype
def _max_pool(name: str, expand_size: int, return_indices: bool):
    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(
        g: jit_utils.GraphContext,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        dilation: Sequence[int],
        ceil_mode: bool,
    ):
        kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
            expand_size, kernel_size, stride, padding, dilation
        )

        if return_indices:
            return _aten_max_pool_with_indices_onnx(
                g,
                input,
                kernel_shape,
                strides,
                pads,
                dilations,
                ceil_mode,
                expand_size + 1,
                ([1] * expand_size),
                ([0] * expand_size),
                ([2 + i for i in range(expand_size)]),
            )
        else:
            return _aten_max_pool_onnx(
                g,
                input,
                kernel_shape,
                strides,
                pads,
                dilations,
                ceil_mode,
                expand_size + 1,
            )

    return symbolic_fn
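# Example (illustrative): `_max_pool("max_pool2d", 2, return_indices=False)`
# builds a symbolic_fn that lowers aten::max_pool2d to a single ONNX MaxPool
# node; e.g. kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=[1, 1],
# ceil_mode=True becomes
#     MaxPool(kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0],
#             dilations=[1, 1], ceil_mode=1)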
# For AvgPool
def _adjust_attributes_of_avg_pool(
    expand_size: int,
    kernel_size: Union[Sequence[int], int],
    stride: Union[Sequence[int], int],
    padding: Union[Sequence[int], int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]:
    """Adjust attributes of avg_pool to match ONNX specification."""

    if isinstance(kernel_size, int):
        kernel_shape = [kernel_size] * expand_size
    else:
        kernel_shape = kernel_size  # type: ignore[assignment]

    if isinstance(padding, int):
        pads = [padding] * expand_size * 2
    elif len(padding) == 1:
        pads = padding * expand_size * 2  # type: ignore[operator, assignment]
    elif len(padding) == 2:
        pads = padding * expand_size  # type: ignore[operator, assignment]
    else:
        pads = padding * 2  # type: ignore[operator, assignment]

    if isinstance(stride, int):
        strides = [stride] * expand_size
    elif not stride:
        strides = kernel_shape
    else:
        strides = stride  # type: ignore[assignment]

    return (kernel_shape, strides, pads)


@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", 1)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", 2)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", 3)],
)
@_beartype.beartype
def _avg_pool(name, expand_size):
    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        kernel_shape, strides, pads = _adjust_attributes_of_avg_pool(
            expand_size, kernel_size, stride, padding
        )

        result = g.op(
            "AveragePool",
            input,
            ceil_mode_i=ceil_mode,
            count_include_pad_i=count_include_pad,
            kernel_shape_i=kernel_shape,
            pads_i=pads,
            strides_i=strides,
        )

        return result

    return symbolic_fn


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn


@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Resize", input, scales, mode_s=mode)


@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    axes: Union[List, torch.Tensor, torch._C.Value],
    starts: Union[List, torch.Tensor, torch._C.Value],
    ends: Union[List, torch.Tensor, torch._C.Value],
    steps: Optional[Union[List, torch.Tensor, torch._C.Value]] = None,
):
    def is_none_value(value):
        if value is None:
            return True
        return (
            isinstance(value, torch._C.Value)
            and value.node().kind() == "prim::Constant"
            and isinstance(value.type(), _C.NoneType)
        )

    def to_slice_input(list_or_value, default_value=None):
        # Convert input param into a 1D torch.Value.
        if is_none_value(list_or_value) and default_value is not None:
            list_or_value = [default_value]

        if isinstance(list_or_value, (list, torch.Tensor)):
            return g.op("Constant", value_t=torch.tensor(list_or_value))

        rank = symbolic_helper._get_tensor_rank(list_or_value)
        if rank == 0:
            return symbolic_helper._unsqueeze_helper(g, list_or_value, [0])
        if rank == 1:
            return list_or_value
        raise errors.SymbolicValueError(
            f"Rank must be 0 or 1, not {rank}", list_or_value
        )

    def get_const_value(list_or_value):
        if isinstance(list_or_value, (list, torch.Tensor)):
            if len(list_or_value) == 1:
                return list_or_value[0]
            return None
        return symbolic_helper._maybe_get_const(list_or_value, "i")

    # Check if slice is a no-op
    if (
        get_const_value(starts) == 0
        and get_const_value(ends) == _constants.INT64_MAX
        and (steps is None or get_const_value(steps) == 1)
    ):
        return input

    axes = to_slice_input(axes)
    starts = to_slice_input(starts, default_value=0)
    ends = to_slice_input(ends, default_value=_constants.INT64_MAX)
    if steps is None:
        return g.op("Slice", input, starts, ends, axes)
    steps = to_slice_input(steps, default_value=1)
    return g.op("Slice", input, starts, ends, axes, steps)
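# Example (illustrative): a slice such as x[0:9223372036854775807:1]
# (starts == 0, ends == INT64_MAX, steps == 1) is recognized by the no-op check
# above and exported as the input itself; any other slice becomes an ONNX
# Slice node with 1-D starts/ends/axes (and optional steps) inputs.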
@_onnx_symbolic("aten::slice")
|
||
@_beartype.beartype
|
||
def slice(g: jit_utils.GraphContext, self, *args):
|
||
if len(args) == 4:
|
||
# aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
|
||
dims, start, end, step = args
|
||
elif len(args) == 3:
|
||
# aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
|
||
start, end, step = args
|
||
dims = [0]
|
||
else:
|
||
raise errors.SymbolicValueError("Unknown aten::slice signature", self)
|
||
|
||
return symbolic_helper._slice_helper(
|
||
g,
|
||
self,
|
||
axes=dims,
|
||
starts=start,
|
||
ends=end,
|
||
steps=step,
|
||
)
|
||
|
||
|
||
@_onnx_symbolic("aten::flip")
|
||
@symbolic_helper.parse_args("v", "is")
|
||
@_beartype.beartype
|
||
def flip(g: jit_utils.GraphContext, input, dims):
|
||
return symbolic_helper._slice_helper(
|
||
g,
|
||
input,
|
||
axes=dims,
|
||
starts=[-1] * len(dims),
|
||
ends=[-_constants.INT64_MAX] * len(dims),
|
||
steps=[-1] * len(dims),
|
||
)
|
||
|
||
|
||
@_onnx_symbolic("aten::fmod")
|
||
@_beartype.beartype
|
||
def fmod(g: jit_utils.GraphContext, input, other):
|
||
return g.op("Mod", input, other, fmod_i=1)
|
||
|
||
|
||
@_onnx_symbolic("aten::embedding_bag")
|
||
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
|
||
@_beartype.beartype
|
||
def embedding_bag(
|
||
g: jit_utils.GraphContext,
|
||
embedding_matrix,
|
||
indices,
|
||
offsets,
|
||
scale_grad_by_freq,
|
||
mode,
|
||
sparse,
|
||
per_sample_weights,
|
||
include_last_offset,
|
||
padding_idx,
|
||
):
|
||
if scale_grad_by_freq and GLOBALS.export_training:
|
||
return symbolic_helper._onnx_unsupported(
|
||
"embedding_bag with scale_grad_by_freq for training mode"
|
||
)
|
||
if padding_idx is not None and padding_idx >= 0:
|
||
raise RuntimeError("embedding_bag with padding_idx")
|
||
|
||
warnings.warn(
|
||
"Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
|
||
"Please use opset 11 or higher to export model for dynamic input shape.'"
|
||
)
|
||
offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
|
||
if offsets_dim_0 is not None:
|
||
if include_last_offset:
|
||
offset_len = offsets_dim_0 - 1
|
||
offsets_extended = offsets
|
||
else:
|
||
offset_len = offsets_dim_0
|
||
offsets_extended = [
|
||
offsets,
|
||
g.op("Constant", value_t=torch.tensor([sys.maxsize])),
|
||
]
|
||
offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
|
||
list_ = []
|
||
for i in range(offset_len):
|
||
start_ = symbolic_helper._unsqueeze_helper(
|
||
g,
|
||
opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
|
||
[0],
|
||
)
|
||
end_ = symbolic_helper._unsqueeze_helper(
|
||
g,
|
||
opset9.select(
|
||
g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
|
||
),
|
||
[0],
|
||
)
|
||
axes_ = g.op("Constant", value_t=torch.tensor([0]))
|
||
indices_row = g.op("Slice", indices, start_, end_, axes_)
|
||
|
||
embeddings = g.op("Gather", embedding_matrix, indices_row)
|
||
if not symbolic_helper._is_none(per_sample_weights):
|
||
per_sample_weights_row = g.op(
|
||
"Slice", per_sample_weights, start_, end_, axes_
|
||
)
|
||
per_sample_weights_row = symbolic_helper._unsqueeze_helper(
|
||
g, per_sample_weights_row, [1]
|
||
)
|
||
embeddings = g.op("Mul", embeddings, per_sample_weights_row)
|
||
if mode == 0:
|
||
embeddings = symbolic_helper._reducesum_helper(
|
||
g, embeddings, axes_i=[0], keepdims_i=0
|
||
)
|
||
elif mode == 1:
|
||
embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
|
||
else:
|
||
embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
|
||
|
||
embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
|
||
list_.append(embeddings)
|
||
|
||
output = g.op("Concat", *list_, axis_i=0)
|
||
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
|
||
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
|
||
return output, None, None, None
|
||
else:
|
||
return symbolic_helper._onnx_unsupported(
|
||
"embedding_bag with unknown shape of offsets for opset 10 is not supported. "
|
||
"please use opset 11 or higher."
|
||
)
|
||
|
||
|
||
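# Example (illustrative): with indices = [0, 2, 1, 4, 3], offsets = [0, 2],
# include_last_offset=False and mode=0 ("sum"), the loop above emits two bags,
# indices[0:2] and indices[2:sys.maxsize] (the sentinel appended to offsets),
# gathers the corresponding rows of embedding_matrix, reduces each bag with
# ReduceSum, and concatenates the two per-bag results along dim 0.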
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
|
||
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
|
||
@_beartype.beartype
|
||
def fake_quantize_per_tensor_affine(
|
||
g: jit_utils.GraphContext,
|
||
inputs,
|
||
scale,
|
||
zero_point,
|
||
quant_min=-128,
|
||
quant_max=127,
|
||
):
|
||
# NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
|
||
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
|
||
if (quant_min, quant_max) == (0, 127):
|
||
symbolic_helper._onnx_opset_unsupported_detailed(
|
||
"fake_quantize_per_tensor_affine",
|
||
10,
|
||
13,
|
||
"Quantize range (0, 127) not supported, requires opset 13 Clip",
|
||
inputs,
|
||
)
|
||
if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
|
||
raise errors.SymbolicValueError(
|
||
f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
|
||
f"Got ({quant_min}, {quant_max})",
|
||
inputs,
|
||
)
|
||
scale = symbolic_helper._maybe_get_scalar(scale)
|
||
if scale is None:
|
||
symbolic_helper._onnx_opset_unsupported_detailed(
|
||
"fake_quantize_per_tensor_affine",
|
||
10,
|
||
13,
|
||
"Non-constant scale not supported",
|
||
inputs,
|
||
)
|
||
scale = scale.float().data # Avoid exporter generating double type
|
||
if quant_min == 0:
|
||
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
|
||
else:
|
||
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
|
||
return g.op(
|
||
"DequantizeLinear",
|
||
g.op("QuantizeLinear", inputs, scale, zero_point),
|
||
scale,
|
||
zero_point,
|
||
)
|
||
|
||
|
||
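# Example (illustrative): fake_quantize_per_tensor_affine(x, scale=0.1,
# zero_point=0, quant_min=0, quant_max=255) exports to
# QuantizeLinear -> DequantizeLinear with a uint8 zero point, so an element
# x = 1.23 round-trips to round(1.23 / 0.1) * 0.1 = 1.2.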
@_onnx_symbolic("aten::isinf")
|
||
@_beartype.beartype
|
||
def isinf(g: jit_utils.GraphContext, input):
|
||
return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE))
|
||
|
||
|
||
@_onnx_symbolic("aten::isfinite")
|
||
@_beartype.beartype
|
||
def isfinite(g: jit_utils.GraphContext, input):
|
||
inf_node = isinf(g, input)
|
||
nan_node = opset9.isnan(g, input)
|
||
return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node))
|
||
|
||
|
||
@_onnx_symbolic("aten::quantize_per_tensor")
|
||
@_beartype.beartype
|
||
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
|
||
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
|
||
# TODO(justinchuby): Extract all the cast ops into a helper function.
|
||
zero_point = g.op(
|
||
"Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type()
|
||
)
|
||
scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
|
||
return symbolic_helper.quantize_helper(g, input, scale, zero_point)
|
||
|
||
|
||
@_onnx_symbolic("aten::dequantize")
|
||
@_beartype.beartype
|
||
def dequantize(g: jit_utils.GraphContext, input):
|
||
return symbolic_helper.dequantize_helper(g, input)[0]
|
||
|
||
|
||
@_onnx_symbolic("aten::nan_to_num")
|
||
@symbolic_helper.parse_args("v", "f", "f", "f")
|
||
@_beartype.beartype
|
||
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
|
||
# Cannot create a int type tensor with inf/nan values, so we simply
|
||
# return the original tensor
|
||
if not symbolic_helper._is_fp(input):
|
||
return input
|
||
input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
|
||
if nan is None:
|
||
nan = 0.0
|
||
nan_cond = opset9.isnan(g, input)
|
||
nan_result = g.op(
|
||
"Where",
|
||
nan_cond,
|
||
g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
|
||
input,
|
||
)
|
||
|
||
# For None values of posinf, neginf we use the greatest/lowest finite
|
||
# value representable by input’s dtype.
|
||
finfo = torch.finfo(input_dtype)
|
||
if posinf is None:
|
||
posinf = finfo.max
|
||
posinf_cond = opset9.logical_and(
|
||
g,
|
||
isinf(g, nan_result),
|
||
opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
|
||
)
|
||
nan_posinf_result = g.op(
|
||
"Where",
|
||
posinf_cond,
|
||
g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
|
||
nan_result,
|
||
)
|
||
|
||
if neginf is None:
|
||
neginf = finfo.min
|
||
neginf_cond = opset9.logical_and(
|
||
g,
|
||
isinf(g, nan_posinf_result),
|
||
opset9.lt(
|
||
g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))
|
||
),
|
||
)
|
||
return g.op(
|
||
"Where",
|
||
neginf_cond,
|
||
g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
|
||
nan_posinf_result,
|
||
)
|
||
|
||
|
||
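# Example (illustrative): for a float32 input, the three chained "Where" ops
# above implement the default nan_to_num behavior, e.g.
#     [nan, inf, -inf, 1.0] -> [0.0, 3.4028235e+38, -3.4028235e+38, 1.0]
# (NaN -> 0.0, +/-inf -> the largest/smallest finite float32 value).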
# Quantized symbolics ---------------------------------------------------------
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
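# Pattern note (illustrative): every `quantized::*` symbolic in this section
# follows the same recipe shown in quantized_linear above:
#   1. DequantizeLinear the quantized inputs (and the requantized bias) to float,
#   2. run the regular opset 9 float symbolic (linear, conv, relu, ...),
#   3. QuantizeLinear the float result with op_scale / op_zero_point.
# Backends therefore see an ordinary float op sandwiched between
# (De)QuantizeLinear nodes.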
@_onnx_symbolic("quantized::linear_relu")
|
||
@_beartype.beartype
|
||
def quantized_linear_relu(
|
||
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.linear(g, input, weight, bias)
|
||
output = opset9.relu(g, output)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::add")
|
||
@_beartype.beartype
|
||
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||
|
||
output = opset9.add(g, x, y)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::add_relu")
|
||
@_beartype.beartype
|
||
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||
|
||
output = opset9.add(g, x, y)
|
||
output = opset9.relu(g, output)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::mul")
|
||
@_beartype.beartype
|
||
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
|
||
|
||
output = opset9.mul(g, x, y)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::hardswish")
|
||
@_beartype.beartype
|
||
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
|
||
output = opset9.hardswish(g, x)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::sigmoid")
|
||
@_beartype.beartype
|
||
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
|
||
output = opset9.sigmoid(g, x)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::leaky_relu")
|
||
@_beartype.beartype
|
||
def quantized_leaky_relu(
|
||
g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
|
||
):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
|
||
output = opset9.leaky_relu(g, x, negative_slope, inplace)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::layer_norm")
|
||
@_beartype.beartype
|
||
def quantized_layer_norm(
|
||
g: jit_utils.GraphContext,
|
||
x,
|
||
normalized_shape,
|
||
weight,
|
||
bias,
|
||
eps,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
|
||
output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::group_norm")
|
||
@_beartype.beartype
|
||
def quantized_group_norm(
|
||
g: jit_utils.GraphContext,
|
||
x,
|
||
num_groups,
|
||
weight,
|
||
bias,
|
||
eps,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
|
||
|
||
output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::instance_norm")
|
||
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
|
||
@_beartype.beartype
|
||
def quantized_instance_norm(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
weight,
|
||
bias,
|
||
eps,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
|
||
output = opset9.instance_norm(
|
||
g, input, weight, bias, None, None, False, 0.0, eps, False
|
||
)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv1d_relu")
|
||
@_beartype.beartype
|
||
def quantized_conv1d_relu(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
|
||
output = opset9.relu(g, output)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv2d_relu")
|
||
@_beartype.beartype
|
||
def quantized_conv2d_relu(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
|
||
output = opset9.relu(g, output)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv3d_relu")
|
||
@_beartype.beartype
|
||
def quantized_conv3d_relu(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)
|
||
output = opset9.relu(g, output)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv1d")
|
||
@_beartype.beartype
|
||
def quantized_conv1d(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv2d")
|
||
@_beartype.beartype
|
||
def quantized_conv2d(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv3d")
|
||
@_beartype.beartype
|
||
def quantized_conv3d(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv_transpose1d")
|
||
@_beartype.beartype
|
||
def quantized_conv_transpose1d(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
output_padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv_transpose2d(
|
||
g, input, weight, bias, stride, padding, output_padding, groups, dilation
|
||
)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv_transpose2d")
|
||
@_beartype.beartype
|
||
def quantized_conv_transpose2d(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
output_padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv_transpose2d(
|
||
g, input, weight, bias, stride, padding, output_padding, groups, dilation
|
||
)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv_transpose3d")
|
||
@_beartype.beartype
|
||
def quantized_conv_transpose3d(
|
||
g: jit_utils.GraphContext,
|
||
q_input,
|
||
q_weight,
|
||
bias,
|
||
stride,
|
||
padding,
|
||
output_padding,
|
||
dilation,
|
||
groups,
|
||
op_scale,
|
||
op_zero_point,
|
||
):
|
||
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
|
||
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
|
||
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
|
||
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
|
||
|
||
output = opset9.conv_transpose3d(
|
||
g, input, weight, bias, stride, padding, output_padding, groups, dilation
|
||
)
|
||
|
||
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::cat")
|
||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||
@_beartype.beartype
|
||
def quantized_cat(
|
||
g: jit_utils.GraphContext,
|
||
q_inputs: _C.Value,
|
||
dim: int,
|
||
op_scale: _C.Value,
|
||
op_zero_point: _C.Value,
|
||
) -> _C.Value:
|
||
unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
|
||
dequantized = [
|
||
symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
|
||
]
|
||
concatenated = g.op("Concat", *dequantized, axis_i=dim)
|
||
return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)
|