[ONNX] Use human-readable enum for dtype scalars (#66822) (#67807)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67807

Also make quoting of string literals consistent.

Test Plan: Imported from OSS

Reviewed By: msaroufim

Differential Revision: D32181309

Pulled By: malfet

fbshipit-source-id: e1053701e3589f0310d8b5ef920359c03c6713f0
Gary Miguel 2021-11-08 14:29:12 -08:00 committed by Facebook GitHub Bot
parent 958d517643
commit eb22d06e5e
6 changed files with 91 additions and 70 deletions
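
The pattern this commit applies, shown as a minimal sketch (assuming the torch.onnx.symbolic_helper layout introduced in the diff below; illustrative only, not the actual torch.onnx sources):

    # illustrative sketch: replace bare integer dtype codes with the new IntEnum
    from torch.onnx.symbolic_helper import ScalarType

    # before: magic indices into scalar_type_to_pytorch_type
    dtype = 6  # float
    dtype = 4  # int64

    # after: the same indices, spelled readably
    dtype = ScalarType.FLOAT  # still compares equal to 6
    dtype = ScalarType.INT64  # still compares equal to 4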

View File

@@ -1,4 +1,4 @@
import enum
import torch
import warnings
import inspect
@@ -911,6 +911,27 @@ scalar_name_to_pytorch = {
}
class ScalarType(enum.IntEnum):
"""A human-readable name for a key into scalar_type_to_pytorch_type."""
UINT8 = 0
INT8 = enum.auto()
SHORT = enum.auto()
INT = enum.auto()
INT64 = enum.auto()
HALF = enum.auto()
FLOAT = enum.auto()
DOUBLE = enum.auto()
COMPLEX32 = enum.auto()
COMPLEX64 = enum.auto()
COMPLEX128 = enum.auto()
BOOL = enum.auto()
QINT8 = enum.auto()
QUINT8 = enum.auto()
QINT32 = enum.auto()
BFLOAT16 = enum.auto()
# This indicates each scalar type's corresponding
# torch type. Related source:
# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
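
Because ScalarType is an IntEnum starting at UINT8 = 0 and counting up with enum.auto(), its members keep the integer values that the old magic numbers encoded, so list lookups such as scalar_type_to_pytorch_type[dtype] behave exactly as before. A quick sanity check, assuming the module layout above (illustrative only):

    from torch.onnx.symbolic_helper import ScalarType

    assert ScalarType.FLOAT == 6      # replaces the old "dtype = 6  # float"
    assert ScalarType.INT64 == 4      # replaces the old "dtype = 4  # int64"
    assert ScalarType.BFLOAT16 == 15  # 16 scalar types in total, UINT8 through BFLOAT16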

View File

@@ -205,23 +205,23 @@ def slice(g, self, *args):
dim = 0
else:
raise NotImplementedError("Unknown aten::slice signature")
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == 'NoneType'
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == 'NoneType'
is_start_onnx_const = start.node().kind() == 'onnx::Constant'
is_end_onnx_const = end.node().kind() == 'onnx::Constant'
step = sym_help._parse_arg(step, 'i')
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
is_start_onnx_const = start.node().kind() == "onnx::Constant"
is_end_onnx_const = end.node().kind() == "onnx::Constant"
step = sym_help._parse_arg(step, "i")
if (not is_start_none and not is_start_onnx_const) or \
(not isinstance(end, int) and not is_end_none and not is_end_onnx_const) or \
(not isinstance(dim, int) and dim.node().kind() != 'onnx::Constant'):
(not isinstance(dim, int) and dim.node().kind() != "onnx::Constant"):
dynamic_slice = True
if is_start_none:
start = g.op("Constant", value_t=torch.tensor(0))
if is_end_none:
end = g.op("Constant", value_t=torch.tensor(9223372036854775807))
else:
start = [0 if is_start_none else sym_help._parse_arg(start, 'i')]
end = [9223372036854775807 if is_end_none else sym_help._parse_arg(end, 'i')]
dim = [sym_help._parse_arg(dim, 'i')]
start = [0 if is_start_none else sym_help._parse_arg(start, "i")]
end = [9223372036854775807 if is_end_none else sym_help._parse_arg(end, "i")]
dim = [sym_help._parse_arg(dim, "i")]
dynamic_slice = False
return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)

View File

@@ -7,7 +7,7 @@ import torch
import torch.onnx.symbolic_helper as sym_help
import warnings
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType
from torch.onnx.symbolic_opset9 import expand, unused, mul
from torch.nn.modules.utils import _single, _pair, _triple
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
@@ -22,7 +22,7 @@ from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_blo
def hardtanh(g, self, min_val, max_val):
dtype = self.type().scalarType()
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
else:
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
min_val = g.op("Constant", value_t=torch.tensor(min_val, dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@@ -54,7 +54,7 @@ def clamp(g, self, min, max):
return clamp_max(g, clamp_min(g, self, min), max)
@parse_args('v', 'v')
@parse_args("v", "v")
def clamp_min(g, self, min):
dtype = self.type().scalarType()
min = g.op("Cast", min, to_i=sym_help.cast_pytorch_to_onnx[dtype])
@@ -65,7 +65,7 @@ def clamp_min(g, self, min):
return g.op("Max", self, min)
@parse_args('v', 'v')
@parse_args("v", "v")
def clamp_max(g, self, max):
dtype = self.type().scalarType()
max = g.op("Cast", max, to_i=sym_help.cast_pytorch_to_onnx[dtype])
@@ -80,7 +80,7 @@ def relu6(g, input):
relu = g.op("Relu", input)
dtype = input.type().scalarType()
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
else:
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
min_val = g.op("Constant", value_t=torch.tensor(0, dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@@ -111,7 +111,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
if len(indices_list) > 1:
for idx_ in range(len(indices_list)):
if indices_list[idx_].type().scalarType() == 'Bool':
if indices_list[idx_].type().scalarType() == "Bool":
indices_list[idx_] = g.op("NonZero", indices_list[idx_])
index = indices_list[0]

View File

@@ -152,7 +152,7 @@ def _reduce_op_symbolic(onnx_op_name):
# all-reduce path
return sym_help._handle_reduce_dim_none(g, self, onnx_op_name)
else:
keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
keepdim = sym_help._get_const(keepdim, "i", "keepdim")
return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
return symbolic

View File

@@ -3,7 +3,7 @@ import torch
import torch.onnx.symbolic_helper as sym_help
import torch.onnx.symbolic_opset9 as sym_opset9
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type, ScalarType
from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore[attr-defined]
import warnings
@@ -209,7 +209,7 @@ def flatten(g, input, start_dim, end_dim):
def _constant_fill(g, sizes, dtype, const_value):
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
result = g.op(
"ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)

View File

@@ -11,7 +11,7 @@ from functools import partial
from functools import wraps
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType
from typing import Optional
from sys import maxsize as maxsize
@@ -456,7 +456,7 @@ def expand(g, self, size, implicit):
# Since onnx::expand supports two-way broadcasting,
# -1 dim value can be exported to onnx as 1
size = sym_help._reshape_helper(g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])))
dtype = 4 # dim type is int64
dtype = ScalarType.INT64
ones = ones_like(g, size, dtype)
neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
size = where(g, g.op("Equal", size, neg_ones), ones, size)
@@ -723,7 +723,7 @@ def silu(g, input):
return g.op("Mul", input, g.op("Sigmoid", input))
def mish(g, input):
return g.op('Mul', input, g.op('Tanh', g.op('Softplus', input)))
return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))
def relu(g, input):
return g.op("Relu", input)
@@ -1439,7 +1439,7 @@ def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_s
bias_ = repeat(g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)))
running_mean_ = repeat(g, running_mean, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)))
running_var_ = repeat(g, running_var, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)))
input_reshaped = g.op('Reshape', input, g.op('Constant', value_t=torch.LongTensor(input_size_reshape)))
input_reshaped = g.op("Reshape", input, g.op("Constant", value_t=torch.LongTensor(input_size_reshape)))
out = batch_norm(g, input_reshaped, weight_, bias_, running_mean_, running_var_, use_input_stats,
momentum, eps, cudnn_enabled)
return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
@@ -1542,9 +1542,9 @@ def type_as(g, self, other):
# We don't know the type of other, bail by emitting ATen
return g.op("ATen", self, other, operator_s="type_as")
else:
raise RuntimeError('Unsupported: ONNX export of type_as for tensor '
'of unknown dtype. Please check if the dtype of the '
'parameter passed to the type_as function is correct.')
raise RuntimeError("Unsupported: ONNX export of type_as for tensor "
"of unknown dtype. Please check if the dtype of the "
"parameter passed to the type_as function is correct.")
@parse_args("v", "v", "i", "f")
@@ -1602,10 +1602,10 @@ def clamp(g, self, min, max):
return clamp_max(g, clamp_min(g, self, min), max)
@parse_args('v', 'v')
@parse_args("v", "v")
def clamp_min(g, self, min):
if sym_help._is_constant(min):
return g.op("Clip", self, min_f=_parse_arg(min, 'f'))
return g.op("Clip", self, min_f=_parse_arg(min, "f"))
else:
dtype = self.type().scalarType()
min = g.op("Cast", min, to_i=sym_help.cast_pytorch_to_onnx[dtype])
@@ -1615,7 +1615,7 @@ def clamp_min(g, self, min):
@parse_args("v", "v")
def clamp_max(g, self, max):
if sym_help._is_constant(max):
return g.op("Clip", self, max_f=_parse_arg(max, 'f'))
return g.op("Clip", self, max_f=_parse_arg(max, "f"))
else:
dtype = self.type().scalarType()
max = g.op("Cast", max, to_i=sym_help.cast_pytorch_to_onnx[dtype])
@@ -1765,7 +1765,7 @@ def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False):
def scalar_tensor(g, scalar, dtype, *options):
dtype = sym_help._get_const(dtype, "i", "dtype")
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
scalar = g.op("Cast", scalar, to_i=sym_help.scalar_type_to_onnx[dtype])
return scalar
@@ -1798,8 +1798,8 @@ def as_tensor(g, data, dtype=None, device=None):
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
# NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
if dtype is None:
dtype = 6 # float
sizes_ = sym_help._maybe_get_const(sizes, 'is')
dtype = ScalarType.FLOAT
sizes_ = sym_help._maybe_get_const(sizes, "is")
if isinstance(sizes_, list) and len(sizes_) == 0:
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
return g.op("ConstantOfShape", sizes,
@@ -1810,7 +1810,7 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
def zeros_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
return g.op("ConstantOfShape", shape,
value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@@ -1826,8 +1826,8 @@ def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False):
@parse_args("v", "i", "v", "v", "v")
def ones(g, sizes, dtype, layout, device, pin_memory=False):
if dtype is None:
dtype = 6 # float
sizes_ = sym_help._maybe_get_const(sizes, 'is')
dtype = ScalarType.FLOAT
sizes_ = sym_help._maybe_get_const(sizes, "is")
if isinstance(sizes_, list) and len(sizes_) == 0:
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
return g.op("ConstantOfShape", sizes,
@@ -1838,7 +1838,7 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
def ones_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
return g.op("ConstantOfShape", shape,
value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@@ -1852,13 +1852,13 @@ def new_ones(g, self, sizes, dtype, layout, device, pin_memory=False):
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
const_value = sym_help._maybe_get_const(value, "t")
if sym_help._is_value(const_value):
dtype = 6 if dtype is None else dtype
dtype = ScalarType.FLOAT if dtype is None else dtype
tmp = zeros(g, sizes, dtype, layout, device)
return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
else:
dtype = sym_help._get_const(dtype, "i", "dtype")
dtype = 6 if dtype is None else dtype
sizes_ = sym_help._maybe_get_const(sizes, 'is')
dtype = ScalarType.FLOAT if dtype is None else dtype
sizes_ = sym_help._maybe_get_const(sizes, "is")
if isinstance(sizes_, list) and len(sizes_) == 0:
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
return g.op("ConstantOfShape", sizes,
@@ -1868,7 +1868,7 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
def full_like(g, input, fill_value, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
fill_value = sym_help._maybe_get_const(fill_value, "f")
dtype = sym_help._get_const(dtype, "i", "dtype")
dtype = 6 if dtype is None else dtype
dtype = ScalarType.FLOAT if dtype is None else dtype
if sym_help._is_value(fill_value):
tmp = zeros_like(g, input, dtype, layout, device)
fill_value = g.op("Cast", fill_value, to_i=sym_help.scalar_type_to_onnx[dtype])
@@ -1912,13 +1912,13 @@ def slice(g, self, *args):
step = _parse_arg(step, "i")
if step != 1:
raise RuntimeError("step!=1 is currently not supported")
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == 'NoneType'
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == 'NoneType'
is_start_onnx_const = start.node().kind() == 'onnx::Constant'
is_end_onnx_const = end.node().kind() == 'onnx::Constant'
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
is_start_onnx_const = start.node().kind() == "onnx::Constant"
is_end_onnx_const = end.node().kind() == "onnx::Constant"
if ((not is_start_none) and (not is_start_onnx_const)) or \
((not is_end_none) and (not is_end_onnx_const)) or \
dim.node().kind() != 'onnx::Constant':
dim.node().kind() != "onnx::Constant":
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
raise RuntimeError("Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
"is a deprecated experimental op. Please use statically allocated "
@@ -1929,18 +1929,18 @@ def slice(g, self, *args):
dim_unsqueezed = sym_help._unsqueeze_helper(g, dim, [0])
return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
else:
start = 0 if is_start_none else _parse_arg(start, 'i')
end = 9223372036854775807 if is_end_none else _parse_arg(end, 'i')
dim = _parse_arg(dim, 'i')
start = 0 if is_start_none else _parse_arg(start, "i")
end = 9223372036854775807 if is_end_none else _parse_arg(end, "i")
dim = _parse_arg(dim, "i")
return sym_help._slice_helper(g, self, axes=[dim], starts=[start], ends=[end])
elif len(args) == 3:
# aten::slice(t[] l, int start, int end, int step) -> t[]
start, end, step = args
dim = 0
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == 'NoneType'
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == 'NoneType'
start = 0 if is_start_none else _parse_arg(start, 'i')
end = 9223372036854775807 if is_end_none else _parse_arg(end, 'i')
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
start = 0 if is_start_none else _parse_arg(start, "i")
end = 9223372036854775807 if is_end_none else _parse_arg(end, "i")
return sym_help._slice_helper(g, self, axes=[dim], starts=[start], ends=[end])
else:
raise NotImplementedError("Unknown aten::slice signature")
@@ -2066,7 +2066,7 @@ def to(g, self, *args):
def repeat(g, self, repeats):
dtype = 4 # int64
dtype = ScalarType.INT64
shape_ = ones_like(g, repeats, dtype)
self = g.op("Expand", self, shape_)
return g.op("Tile", self, repeats)
@@ -2348,7 +2348,7 @@ def lstm_cell(g, self, hidden, w_ih, w_hh, b_ih, b_hh):
hidden = [sym_help._unsqueeze_helper(g, x, [0]) for x in hidden]
weight = (w_ih, w_hh, b_ih, b_hh) if sym_help._is_tensor(b_ih) else (w_ih, w_hh)
has_biases = True if sym_help._is_tensor(b_ih) else False
_, h_outs, c_outs = _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers=1,
_, h_outs, c_outs = _generic_rnn(g, "LSTM", input, hidden, weight, has_biases, num_layers=1,
dropout=0, train=0, bidirectional=False, batch_first=False)
return sym_help._squeeze_helper(g, h_outs, [0]), sym_help._squeeze_helper(g, c_outs, [0])
@@ -2434,7 +2434,7 @@ def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total
def randn(g, shapes, dtype, *options):
dtype = sym_help._get_const(dtype, "i", "dtype")
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
shape = sym_help._maybe_get_const(shapes, "is")
if sym_help._is_value(shape):
shape_const = g.op("ConstantOfShape", shapes,
@@ -2446,7 +2446,7 @@ def randn(g, shapes, dtype, *options):
def rand(g, shapes, dtype, *options):
dtype = sym_help._get_const(dtype, "i", "dtype")
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
shape = sym_help._maybe_get_const(shapes, "is")
if sym_help._is_value(shape):
shape_const = g.op("ConstantOfShape", shapes,
@@ -2458,14 +2458,14 @@ def rand(g, shapes, dtype, *options):
def randn_like(g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None):
dtype = sym_help._get_const(dtype, "i", "dtype")
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
return g.op("RandomNormalLike", self, dtype_i=sym_help.scalar_type_to_onnx[dtype])
def rand_like(g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None):
dtype = sym_help._get_const(dtype, "i", "dtype")
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
return g.op("RandomUniformLike", self, dtype_i=sym_help.scalar_type_to_onnx[dtype])
@@ -2484,12 +2484,12 @@ def bernoulli(g, input, generator=None, out=None):
dtype = sym_help._try_get_scalar_type(input)
if dtype is None:
return _unimplemented("Bernoulli", "input dtype not accessible")
p = g.op('RandomUniformLike', input, high_f=1.0, low_f=0.0, dtype_i=sym_help.cast_pytorch_to_onnx[dtype])
output = g.op('Less', p, input)
p = g.op("RandomUniformLike", input, high_f=1.0, low_f=0.0, dtype_i=sym_help.cast_pytorch_to_onnx[dtype])
output = g.op("Less", p, input)
return g.op("Cast", output, to_i=sym_help.cast_pytorch_to_onnx[dtype])
@parse_args('v')
@parse_args("v")
def log_sigmoid(g, input):
p = g.op("Sigmoid", input)
return g.op("Log", p)
@@ -2965,13 +2965,13 @@ def baddbmm(g, self, batch1, batch2, beta, alpha):
return add(g, mul_a, mul_b)
@parse_args('v', 's')
@parse_args("v", "s")
def meshgrid(g, tensor_list, indexing: Optional[str] = None):
if indexing is None:
indexing = 'ij'
elif indexing not in {'ij', 'xy'}:
raise ValueError(f'Unsupported indexing: {indexing}')
if indexing == 'xy':
indexing = "ij"
elif indexing not in {"ij", "xy"}:
raise ValueError(f"Unsupported indexing: {indexing}")
if indexing == "xy":
tensor_list[0], tensor_list[1] = tensor_list[1], tensor_list[0]
tensors = [sym_help._reshape_helper(g, t, g.op("Constant", value_t=torch.LongTensor([-1])))
for t in sym_help._unpack_list(tensor_list)]
@@ -2983,7 +2983,7 @@ def meshgrid(g, tensor_list, indexing: Optional[str] = None):
shape_i[i] = tensors_shape[i]
t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0))
out.append(g.op("Expand", t_reshaped, out_shape))
if indexing == 'xy':
if indexing == "xy":
out[0], out[1] = out[1], out[0]
return g.op("prim::ListConstruct", *out)
@@ -3217,11 +3217,11 @@ def dot(g, self, other):
return matmul(g, self, other)
@parse_args('v', 'v')
@parse_args("v", "v")
def fill(g, self, value):
dtype = self.type().scalarType()
if dtype is None:
dtype = 6 # float
dtype = ScalarType.FLOAT
else:
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
@@ -3284,7 +3284,7 @@ def index_add(g, self, dim, index, other):
return scatter_add(g, self, dim, expand_as(g, index, other), other)
@parse_args('v', 'is', 'is')
@parse_args("v", "is", "is")
def roll(g, self, shifts, dims):
assert len(shifts) == len(dims)