import torch
import torch.onnx.symbolic_helper as sym_help
import torch.onnx.symbolic_opset9 as sym_opset9

from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type, ScalarType
from torch.onnx.symbolic_opset9 import _cast_Float  # type: ignore[attr-defined]

import warnings

# Note [ONNX operators that are added/updated from opset 8 to opset 9]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# New operators:
#   Compress
#   ConstantOfShape
#   EyeLike
#   MaxUnpool
#   OneHot
#   Sinh
#   Cosh
#   Asinh
#   Acosh
#   Atanh
#   Shrink
#   IsNaN
#   Sign
#   Erf
#   Scatter
#   Where
#   NonZero
#   TfIdfVectorizer
#   MeanVarianceNormalization
#
# Updated operators:
#   BatchNormalization: removed spatial attribute.
#   Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types {integers} supported.
#   Cast: more data types {string} supported.
#   Upsample: moved scales from attribute to input.
#   Scan

block_listed_operators = [
    "nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
    "arange", "masked_fill",
    "index_fill", "index_copy", "repeat_interleave",
    "any", "all",
]

# These ops are not available in opset 8; _block_list_in_opset registers a
# symbolic that fails the export and asks for a newer opset version.
for block_listed_op in block_listed_operators:
    vars()[block_listed_op] = _block_list_in_opset(block_listed_op)


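# Illustrative sketch (not from the original file): exporting a model that hits
# one of the block-listed ops above with opset_version=8 fails at export time,
# so the usual fix is to re-export with opset_version >= 9. The module name
# below is hypothetical.
#
#     import torch
#
#     class WhereModule(torch.nn.Module):
#         def forward(self, x, y):
#             return torch.where(x > 0, x, y)
#
#     x, y = torch.randn(2, 2), torch.randn(2, 2)
#     # Fails: "where" is block-listed for opset 8.
#     torch.onnx.export(WhereModule(), (x, y), "where.onnx", opset_version=8)
#     # Works: the ONNX Where operator was added in opset 9.
#     torch.onnx.export(WhereModule(), (x, y), "where.onnx", opset_version=9)
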
def _interpolate(name, dim, interpolate_mode):
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
        sym_help._interpolate_warning(interpolate_mode)
        align_corners = sym_help._maybe_get_scalar(align_corners)
        if align_corners:
            return _unimplemented(name, "align_corners == True")
        output_size = sym_help._maybe_get_const(output_size, "is")
        if sym_help._is_value(output_size):
            return _unimplemented(name, "torch._C.Value (output_size) indexing")
        if scales is None:
            # Keep the batch and channel dims at scale 1; derive the spatial scales
            # from the requested output size and the static input size.
            scales = [1. if i < 2 else
                      float(output_size[-(dim - i)]) / float(input.type().sizes()[-(dim - i)])
                      for i in range(0, dim)]
        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)
    return symbolic_fn


upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest")
upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest")
upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest")
upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear")
upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")


def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias):
    align_corners = sym_help._maybe_get_const(align_corners, "b")
    if not sym_help._is_none(align_corners) and align_corners:
        return _unimplemented("interpolate", "align_corners == True")

    if not sym_help._is_none(scale_factor) and sym_help._is_value(scale_factor):
        return _unimplemented("interpolate", "dynamic scales in opset 8")

    if not sym_help._is_none(size) and sym_help._is_value(size):
        return _unimplemented("interpolate", "dynamic size in opset 8")

    # Opset 8 Upsample takes scales as an attribute (opset 9 moves them to an
    # input), so only statically known scales/sizes can be exported here.
    scales, mode = sym_help._interpolate_get_scales_and_mode(g, input, size, scale_factor,
                                                             mode, align_corners)
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)


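# Illustrative sketch (not from the original file): at opset 8 a fixed-factor
# upsample exports through the symbolic functions above to a single ONNX
# Upsample node with its scales stored as an attribute.
#
#     import torch
#
#     model = torch.nn.Upsample(scale_factor=2, mode="nearest")
#     torch.onnx.export(model, torch.randn(1, 3, 8, 8), "upsample.onnx",
#                       opset_version=8)
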
# NOTE: We should create a wrapper for this kind of operation, after resolving the
#       shape/type propagation issue for "cast" operators. Some symbolic functions
#       depend on shape information of the input tensor, which is lost after casting.
def _try_cast_integer_to_float(g, *args):
    floating_scalar_types = ["Half", "Float", "Double"]
    old_type = None
    # Cast the input tensors to Float if their scalarType is known and is not a
    # floating-point type. If casting is performed, return the old scalarType,
    # otherwise return None.
    arg0_type = args[0].type().scalarType()
    if arg0_type is not None:
        old_type = arg0_type
        if old_type not in floating_scalar_types:
            args = tuple(_cast_Float(g, arg, False) for arg in args)
        else:
            return (None,) + args
    else:
        warnings.warn("Only floating datatype is supported for these operators: "
                      "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
                      "the onnx model to be incorrect, if inputs have integer datatypes.")
    return (old_type,) + args


def _cast_to_type(g, input, to_type):
    if to_type is None:
        return input
    return getattr(sym_opset9, "_cast_{}".format(to_type))(g, input, False)


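# Illustrative sketch (not from the original file): the wrapper mentioned in the
# NOTE above could look roughly like the hypothetical helper below, which casts
# integer inputs to Float, runs the op, and casts the result back. It only
# restates the pattern already used by gt/bmm/prelu/mm/flatten in this file and
# would still hit the shape-propagation issue described in the NOTE.
#
#     def _with_float_cast(g, op_name, *args, **kwargs):  # hypothetical helper
#         old_type, *args = _try_cast_integer_to_float(g, *args)
#         return _cast_to_type(g, g.op(op_name, *args, **kwargs), old_type)
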
def _comparison_operator(g, input, other, op_name):
    other = sym_help._maybe_get_scalar(other)
    other = sym_help._if_scalar_type_as(g, other, input)
    _, input, other = _try_cast_integer_to_float(g, input, other)
    return g.op(op_name, input, other)


# NOTE: For the symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
#       integer input types are not supported in opset 8, so inputs are cast to
#       float when possible.
def gt(g, input, other):
    return _comparison_operator(g, input, other, "Greater")


def lt(g, input, other):
    return _comparison_operator(g, input, other, "Less")


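# Illustrative sketch (not from the original file): because opset 8 Greater/Less
# only accept floating-point inputs, exporting a comparison on integer tensors
# goes through the cast-to-Float path above, so the exported graph compares
# float values rather than the original integers.
#
#     import torch
#
#     class GreaterModule(torch.nn.Module):
#         def forward(self, x, y):
#             return x > y
#
#     x = torch.arange(6, dtype=torch.int64).reshape(2, 3)
#     y = torch.ones(2, 3, dtype=torch.int64)
#     torch.onnx.export(GreaterModule(), (x, y), "greater.onnx", opset_version=8)
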
def bmm(g, self, other):
    if _try_get_scalar_type(self):
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
        return _cast_to_type(g, g.op("MatMul", self, other), old_type)
    else:
        return g.op("MatMul", self, other)


def matmul(g, self, other):
    return bmm(g, self, other)


def prelu(g, self, weight):
    self_rank = sym_help._get_tensor_rank(self)
    if self_rank is not None and self_rank > 2:
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
    if _try_get_scalar_type(self):
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
        return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
    else:
        return g.op("PRelu", self, weight)


def mm(g, self, other):
    # Create a dummy C tensor. It is only needed for API purposes; its value is
    # never used since beta = 0.
    ty = sym_help._try_get_scalar_type(self, other).lower()
    C = g.constant(0, [1], ty)
    if _try_get_scalar_type(self):
        old_type, self, other, C = _try_cast_integer_to_float(g, self, other, C)
        return _cast_to_type(g, g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0), old_type)
    else:
        return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)


@parse_args("v", "v", "v", "t", "t")
def addmm(g, self, mat1, mat2, beta, alpha):
    if _try_get_scalar_type(self):
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
        return _cast_to_type(
            g, g.op("Gemm", mat1, mat2, self,
                    beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha)), old_type)
    else:
        return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))


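# Illustrative note (not from the original file): ONNX Gemm computes
# Y = alpha * A @ B + beta * C, so torch.mm(a, b) maps to Gemm with alpha=1 and
# beta=0 (the C input above is a dummy constant), while torch.addmm maps all
# three operands directly.
#
#     import torch
#
#     class AddmmModule(torch.nn.Module):
#         def forward(self, c, a, b):
#             return torch.addmm(c, a, b, beta=0.5, alpha=2.0)
#
#     inputs = (torch.randn(2, 4), torch.randn(2, 3), torch.randn(3, 4))
#     torch.onnx.export(AddmmModule(), inputs, "addmm.onnx", opset_version=8)
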
def flatten(g, input, start_dim, end_dim):
    start_dim_i = sym_help._get_const(start_dim, "i", "start_dim")
    end_dim_i = sym_help._get_const(end_dim, "i", "end_dim")

    dim = input.type().dim()
    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        if _try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(g, g.op("Flatten", input, axis_i=start_dim_i), old_type)
        else:
            return g.op("Flatten", input, axis_i=start_dim_i)
    if start_dim_i == 0 and end_dim_i == dim - 2:
        if _try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type)
        else:
            return g.op("Flatten", input, axis_i=end_dim_i + 1)

    return sym_opset9.flatten(g, input, start_dim, end_dim)


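# Illustrative note (not from the original file): ONNX Flatten reshapes a tensor
# to 2D around a single axis, so only the two patterns handled above map to it
# directly; every other start_dim/end_dim combination falls back to the opset 9
# symbolic, which handles the general case.
#
#     import torch
#
#     class FlattenModule(torch.nn.Module):
#         def forward(self, x):
#             return torch.flatten(x, start_dim=1)   # (N, C, H, W) -> (N, C*H*W)
#
#     torch.onnx.export(FlattenModule(), torch.randn(2, 3, 4, 5), "flatten.onnx",
#                       opset_version=8)
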
def _constant_fill(g, sizes, dtype, const_value):
    if dtype is None:
        dtype = ScalarType.FLOAT
    if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
        # ConstantFill takes its fill value as a float attribute, so for
        # non-floating dtypes fill as Float first and cast the result back to
        # the requested dtype.
        result = g.op(
            "ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1,
            value_f=const_value)
        return sym_help._cast_func_template(sym_help.scalar_type_to_onnx[dtype], g, result, None)
    else:
        return g.op("ConstantFill", sizes, dtype_i=sym_help.scalar_type_to_onnx[dtype], input_as_shape_i=1,
                    value_f=const_value)


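# Illustrative note (not from the original file): opset 8 predates the ONNX
# ConstantOfShape operator (listed above as new in opset 9), so the factory
# functions below are emitted with the ConstantFill op instead, e.g.:
#
#     import torch
#
#     class ZerosLikeModule(torch.nn.Module):
#         def forward(self, x):
#             return torch.zeros_like(x) + x
#
#     torch.onnx.export(ZerosLikeModule(), torch.randn(2, 3), "zeros_like.onnx",
#                       opset_version=8)
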
@parse_args("v", "i", "v", "v", "v", "v")
def empty(g, sizes, dtype, layout, device, pin_memory=False, memory_format=None):
    return zeros(g, sizes, dtype, layout, device, pin_memory)


@parse_args("v", "i", "v", "v", "v", "v")
def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
    return zeros_like(g, input, dtype, layout, device, pin_memory)


@parse_args("v", "i", "v", "v", "v")
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
    # NOTE: no way to set device and layout in ONNX, so we ignore it
    return _constant_fill(g, sizes, dtype, 0)


@parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 0)


@parse_args("v", "i", "v", "v", "v")
def ones(g, sizes, dtype, layout, device, pin_memory=False):
    return _constant_fill(g, sizes, dtype, 1)


@parse_args("v", "i", "v", "v", "v", "v")
def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 1)


def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    const_value = sym_help._maybe_get_const(value, "t")
    if sym_help._is_value(const_value):
        tmp = zeros(g, sizes, dtype, layout, device)
        return sym_opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    else:
        dtype = sym_help._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)


@parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, fill_value)


def repeat(g, self, repeats):
    if not sym_help._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    if sym_help._is_packed_list(repeats):
        repeat_size_len = len(sym_help._unpack_list(repeats))
    else:
        const_repeats = sym_help._maybe_get_const(repeats, "is")
        repeat_size_len = len(const_repeats)
    if self.isCompleteTensor():
        sizes = self.type().sizes()
        diff_dims = repeat_size_len - len(sizes)
        if diff_dims > 0:
            # ONNX Tile requires the input rank to match len(repeats), so prepend
            # singleton dimensions before tiling.
            self = sym_opset9.view(g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes)))
    return g.op("Tile", self, repeats)
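

# Illustrative sketch (not from the original file): Tensor.repeat maps to ONNX
# Tile; when len(repeats) exceeds the input rank, the input is first reshaped
# with leading singleton dimensions as in the function above.
#
#     import torch
#
#     class RepeatModule(torch.nn.Module):
#         def forward(self, x):
#             return x.repeat(2, 1, 3)   # rank-2 input viewed as (1, H, W) first
#
#     torch.onnx.export(RepeatModule(), torch.randn(4, 5), "repeat.onnx",
#                       opset_version=8)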