mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67807 Also make quoting of string literals consistent. Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32181309 Pulled By: malfet fbshipit-source-id: e1053701e3589f0310d8b5ef920359c03c6713f0
This commit is contained in:
parent
958d517643
commit
eb22d06e5e
|
|
@ -1,4 +1,4 @@
|
|||
|
||||
import enum
|
||||
import torch
|
||||
import warnings
|
||||
import inspect
|
||||
|
|
@ -911,6 +911,27 @@ scalar_name_to_pytorch = {
|
|||
}
|
||||
|
||||
|
||||
|
||||
class ScalarType(enum.IntEnum):
|
||||
"""A human-readable name for a key into scalar_type_to_pytorch_type."""
|
||||
UINT8 = 0
|
||||
INT8 = enum.auto()
|
||||
SHORT = enum.auto()
|
||||
INT = enum.auto()
|
||||
INT64 = enum.auto()
|
||||
HALF = enum.auto()
|
||||
FLOAT = enum.auto()
|
||||
DOUBLE = enum.auto()
|
||||
COMPLEX32 = enum.auto()
|
||||
COMPLEX64 = enum.auto()
|
||||
COMPLEX128 = enum.auto()
|
||||
BOOL = enum.auto()
|
||||
QINT8 = enum.auto()
|
||||
QUINT8 = enum.auto()
|
||||
QINT32 = enum.auto()
|
||||
BFLOAT16 = enum.auto()
|
||||
|
||||
|
||||
# This indicates each scalar type's corresponding
|
||||
# torch type. Related source:
|
||||
# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
|
||||
|
|
|
|||
|
|
@ -205,23 +205,23 @@ def slice(g, self, *args):
|
|||
dim = 0
|
||||
else:
|
||||
raise NotImplementedError("Unknown aten::slice signature")
|
||||
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == 'NoneType'
|
||||
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == 'NoneType'
|
||||
is_start_onnx_const = start.node().kind() == 'onnx::Constant'
|
||||
is_end_onnx_const = end.node().kind() == 'onnx::Constant'
|
||||
step = sym_help._parse_arg(step, 'i')
|
||||
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
|
||||
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
|
||||
is_start_onnx_const = start.node().kind() == "onnx::Constant"
|
||||
is_end_onnx_const = end.node().kind() == "onnx::Constant"
|
||||
step = sym_help._parse_arg(step, "i")
|
||||
if (not is_start_none and not is_start_onnx_const) or \
|
||||
(not isinstance(end, int) and not is_end_none and not is_end_onnx_const) or \
|
||||
(not isinstance(dim, int) and dim.node().kind() != 'onnx::Constant'):
|
||||
(not isinstance(dim, int) and dim.node().kind() != "onnx::Constant"):
|
||||
dynamic_slice = True
|
||||
if is_start_none:
|
||||
start = g.op("Constant", value_t=torch.tensor(0))
|
||||
if is_end_none:
|
||||
end = g.op("Constant", value_t=torch.tensor(9223372036854775807))
|
||||
else:
|
||||
start = [0 if is_start_none else sym_help._parse_arg(start, 'i')]
|
||||
end = [9223372036854775807 if is_end_none else sym_help._parse_arg(end, 'i')]
|
||||
dim = [sym_help._parse_arg(dim, 'i')]
|
||||
start = [0 if is_start_none else sym_help._parse_arg(start, "i")]
|
||||
end = [9223372036854775807 if is_end_none else sym_help._parse_arg(end, "i")]
|
||||
dim = [sym_help._parse_arg(dim, "i")]
|
||||
dynamic_slice = False
|
||||
return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch
|
|||
import torch.onnx.symbolic_helper as sym_help
|
||||
import warnings
|
||||
|
||||
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list
|
||||
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType
|
||||
from torch.onnx.symbolic_opset9 import expand, unused, mul
|
||||
from torch.nn.modules.utils import _single, _pair, _triple
|
||||
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
|
||||
|
|
@ -22,7 +22,7 @@ from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_blo
|
|||
def hardtanh(g, self, min_val, max_val):
|
||||
dtype = self.type().scalarType()
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
else:
|
||||
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
|
||||
min_val = g.op("Constant", value_t=torch.tensor(min_val, dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
|
||||
|
|
@ -54,7 +54,7 @@ def clamp(g, self, min, max):
|
|||
return clamp_max(g, clamp_min(g, self, min), max)
|
||||
|
||||
|
||||
@parse_args('v', 'v')
|
||||
@parse_args("v", "v")
|
||||
def clamp_min(g, self, min):
|
||||
dtype = self.type().scalarType()
|
||||
min = g.op("Cast", min, to_i=sym_help.cast_pytorch_to_onnx[dtype])
|
||||
|
|
@ -65,7 +65,7 @@ def clamp_min(g, self, min):
|
|||
return g.op("Max", self, min)
|
||||
|
||||
|
||||
@parse_args('v', 'v')
|
||||
@parse_args("v", "v")
|
||||
def clamp_max(g, self, max):
|
||||
dtype = self.type().scalarType()
|
||||
max = g.op("Cast", max, to_i=sym_help.cast_pytorch_to_onnx[dtype])
|
||||
|
|
@ -80,7 +80,7 @@ def relu6(g, input):
|
|||
relu = g.op("Relu", input)
|
||||
dtype = input.type().scalarType()
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
else:
|
||||
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
|
||||
min_val = g.op("Constant", value_t=torch.tensor(0, dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
|
||||
|
|
@ -111,7 +111,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
|
|||
|
||||
if len(indices_list) > 1:
|
||||
for idx_ in range(len(indices_list)):
|
||||
if indices_list[idx_].type().scalarType() == 'Bool':
|
||||
if indices_list[idx_].type().scalarType() == "Bool":
|
||||
indices_list[idx_] = g.op("NonZero", indices_list[idx_])
|
||||
index = indices_list[0]
|
||||
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ def _reduce_op_symbolic(onnx_op_name):
|
|||
# all-reduce path
|
||||
return sym_help._handle_reduce_dim_none(g, self, onnx_op_name)
|
||||
else:
|
||||
keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
|
||||
keepdim = sym_help._get_const(keepdim, "i", "keepdim")
|
||||
return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
|
||||
return symbolic
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import torch.onnx.symbolic_helper as sym_help
|
||||
import torch.onnx.symbolic_opset9 as sym_opset9
|
||||
|
||||
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
|
||||
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type, ScalarType
|
||||
from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore[attr-defined]
|
||||
|
||||
import warnings
|
||||
|
|
@ -209,7 +209,7 @@ def flatten(g, input, start_dim, end_dim):
|
|||
|
||||
def _constant_fill(g, sizes, dtype, const_value):
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
|
||||
result = g.op(
|
||||
"ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from functools import partial
|
|||
from functools import wraps
|
||||
|
||||
import torch.onnx.symbolic_helper as sym_help
|
||||
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
|
||||
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType
|
||||
|
||||
from typing import Optional
|
||||
from sys import maxsize as maxsize
|
||||
|
|
@ -456,7 +456,7 @@ def expand(g, self, size, implicit):
|
|||
# Since onnx::expand supports two-way broadcasting,
|
||||
# -1 dim value can be exported to onnx as 1
|
||||
size = sym_help._reshape_helper(g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])))
|
||||
dtype = 4 # dim type is int64
|
||||
dtype = ScalarType.INT64
|
||||
ones = ones_like(g, size, dtype)
|
||||
neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
|
||||
size = where(g, g.op("Equal", size, neg_ones), ones, size)
|
||||
|
|
@ -723,7 +723,7 @@ def silu(g, input):
|
|||
return g.op("Mul", input, g.op("Sigmoid", input))
|
||||
|
||||
def mish(g, input):
|
||||
return g.op('Mul', input, g.op('Tanh', g.op('Softplus', input)))
|
||||
return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))
|
||||
|
||||
def relu(g, input):
|
||||
return g.op("Relu", input)
|
||||
|
|
@ -1439,7 +1439,7 @@ def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_s
|
|||
bias_ = repeat(g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)))
|
||||
running_mean_ = repeat(g, running_mean, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)))
|
||||
running_var_ = repeat(g, running_var, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)))
|
||||
input_reshaped = g.op('Reshape', input, g.op('Constant', value_t=torch.LongTensor(input_size_reshape)))
|
||||
input_reshaped = g.op("Reshape", input, g.op("Constant", value_t=torch.LongTensor(input_size_reshape)))
|
||||
out = batch_norm(g, input_reshaped, weight_, bias_, running_mean_, running_var_, use_input_stats,
|
||||
momentum, eps, cudnn_enabled)
|
||||
return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
|
||||
|
|
@ -1542,9 +1542,9 @@ def type_as(g, self, other):
|
|||
# We don't know the type of other, bail by emitting ATen
|
||||
return g.op("ATen", self, other, operator_s="type_as")
|
||||
else:
|
||||
raise RuntimeError('Unsupported: ONNX export of type_as for tensor '
|
||||
'of unknown dtype. Please check if the dtype of the '
|
||||
'parameter passed to the type_as function is correct.')
|
||||
raise RuntimeError("Unsupported: ONNX export of type_as for tensor "
|
||||
"of unknown dtype. Please check if the dtype of the "
|
||||
"parameter passed to the type_as function is correct.")
|
||||
|
||||
|
||||
@parse_args("v", "v", "i", "f")
|
||||
|
|
@ -1602,10 +1602,10 @@ def clamp(g, self, min, max):
|
|||
return clamp_max(g, clamp_min(g, self, min), max)
|
||||
|
||||
|
||||
@parse_args('v', 'v')
|
||||
@parse_args("v", "v")
|
||||
def clamp_min(g, self, min):
|
||||
if sym_help._is_constant(min):
|
||||
return g.op("Clip", self, min_f=_parse_arg(min, 'f'))
|
||||
return g.op("Clip", self, min_f=_parse_arg(min, "f"))
|
||||
else:
|
||||
dtype = self.type().scalarType()
|
||||
min = g.op("Cast", min, to_i=sym_help.cast_pytorch_to_onnx[dtype])
|
||||
|
|
@ -1615,7 +1615,7 @@ def clamp_min(g, self, min):
|
|||
@parse_args("v", "v")
|
||||
def clamp_max(g, self, max):
|
||||
if sym_help._is_constant(max):
|
||||
return g.op("Clip", self, max_f=_parse_arg(max, 'f'))
|
||||
return g.op("Clip", self, max_f=_parse_arg(max, "f"))
|
||||
else:
|
||||
dtype = self.type().scalarType()
|
||||
max = g.op("Cast", max, to_i=sym_help.cast_pytorch_to_onnx[dtype])
|
||||
|
|
@ -1765,7 +1765,7 @@ def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False):
|
|||
def scalar_tensor(g, scalar, dtype, *options):
|
||||
dtype = sym_help._get_const(dtype, "i", "dtype")
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
scalar = g.op("Cast", scalar, to_i=sym_help.scalar_type_to_onnx[dtype])
|
||||
return scalar
|
||||
|
||||
|
|
@ -1798,8 +1798,8 @@ def as_tensor(g, data, dtype=None, device=None):
|
|||
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
|
||||
# NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
sizes_ = sym_help._maybe_get_const(sizes, 'is')
|
||||
dtype = ScalarType.FLOAT
|
||||
sizes_ = sym_help._maybe_get_const(sizes, "is")
|
||||
if isinstance(sizes_, list) and len(sizes_) == 0:
|
||||
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
|
||||
return g.op("ConstantOfShape", sizes,
|
||||
|
|
@ -1810,7 +1810,7 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
|
|||
def zeros_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
shape = g.op("Shape", input)
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
return g.op("ConstantOfShape", shape,
|
||||
value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
|
||||
|
||||
|
|
@ -1826,8 +1826,8 @@ def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False):
|
|||
@parse_args("v", "i", "v", "v", "v")
|
||||
def ones(g, sizes, dtype, layout, device, pin_memory=False):
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
sizes_ = sym_help._maybe_get_const(sizes, 'is')
|
||||
dtype = ScalarType.FLOAT
|
||||
sizes_ = sym_help._maybe_get_const(sizes, "is")
|
||||
if isinstance(sizes_, list) and len(sizes_) == 0:
|
||||
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
|
||||
return g.op("ConstantOfShape", sizes,
|
||||
|
|
@ -1838,7 +1838,7 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
|
|||
def ones_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
shape = g.op("Shape", input)
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
return g.op("ConstantOfShape", shape,
|
||||
value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
|
||||
|
||||
|
|
@ -1852,13 +1852,13 @@ def new_ones(g, self, sizes, dtype, layout, device, pin_memory=False):
|
|||
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
|
||||
const_value = sym_help._maybe_get_const(value, "t")
|
||||
if sym_help._is_value(const_value):
|
||||
dtype = 6 if dtype is None else dtype
|
||||
dtype = ScalarType.FLOAT if dtype is None else dtype
|
||||
tmp = zeros(g, sizes, dtype, layout, device)
|
||||
return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
|
||||
else:
|
||||
dtype = sym_help._get_const(dtype, "i", "dtype")
|
||||
dtype = 6 if dtype is None else dtype
|
||||
sizes_ = sym_help._maybe_get_const(sizes, 'is')
|
||||
dtype = ScalarType.FLOAT if dtype is None else dtype
|
||||
sizes_ = sym_help._maybe_get_const(sizes, "is")
|
||||
if isinstance(sizes_, list) and len(sizes_) == 0:
|
||||
sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
|
||||
return g.op("ConstantOfShape", sizes,
|
||||
|
|
@ -1868,7 +1868,7 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
|
|||
def full_like(g, input, fill_value, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
fill_value = sym_help._maybe_get_const(fill_value, "f")
|
||||
dtype = sym_help._get_const(dtype, "i", "dtype")
|
||||
dtype = 6 if dtype is None else dtype
|
||||
dtype = ScalarType.FLOAT if dtype is None else dtype
|
||||
if sym_help._is_value(fill_value):
|
||||
tmp = zeros_like(g, input, dtype, layout, device)
|
||||
fill_value = g.op("Cast", fill_value, to_i=sym_help.scalar_type_to_onnx[dtype])
|
||||
|
|
@ -1912,13 +1912,13 @@ def slice(g, self, *args):
|
|||
step = _parse_arg(step, "i")
|
||||
if step != 1:
|
||||
raise RuntimeError("step!=1 is currently not supported")
|
||||
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == 'NoneType'
|
||||
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == 'NoneType'
|
||||
is_start_onnx_const = start.node().kind() == 'onnx::Constant'
|
||||
is_end_onnx_const = end.node().kind() == 'onnx::Constant'
|
||||
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
|
||||
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
|
||||
is_start_onnx_const = start.node().kind() == "onnx::Constant"
|
||||
is_end_onnx_const = end.node().kind() == "onnx::Constant"
|
||||
if ((not is_start_none) and (not is_start_onnx_const)) or \
|
||||
((not is_end_none) and (not is_end_onnx_const)) or \
|
||||
dim.node().kind() != 'onnx::Constant':
|
||||
dim.node().kind() != "onnx::Constant":
|
||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
|
||||
raise RuntimeError("Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
|
||||
"is a deprecated experimental op. Please use statically allocated "
|
||||
|
|
@ -1929,18 +1929,18 @@ def slice(g, self, *args):
|
|||
dim_unsqueezed = sym_help._unsqueeze_helper(g, dim, [0])
|
||||
return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
|
||||
else:
|
||||
start = 0 if is_start_none else _parse_arg(start, 'i')
|
||||
end = 9223372036854775807 if is_end_none else _parse_arg(end, 'i')
|
||||
dim = _parse_arg(dim, 'i')
|
||||
start = 0 if is_start_none else _parse_arg(start, "i")
|
||||
end = 9223372036854775807 if is_end_none else _parse_arg(end, "i")
|
||||
dim = _parse_arg(dim, "i")
|
||||
return sym_help._slice_helper(g, self, axes=[dim], starts=[start], ends=[end])
|
||||
elif len(args) == 3:
|
||||
# aten::slice(t[] l, int start, int end, int step) -> t[]
|
||||
start, end, step = args
|
||||
dim = 0
|
||||
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == 'NoneType'
|
||||
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == 'NoneType'
|
||||
start = 0 if is_start_none else _parse_arg(start, 'i')
|
||||
end = 9223372036854775807 if is_end_none else _parse_arg(end, 'i')
|
||||
is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
|
||||
is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
|
||||
start = 0 if is_start_none else _parse_arg(start, "i")
|
||||
end = 9223372036854775807 if is_end_none else _parse_arg(end, "i")
|
||||
return sym_help._slice_helper(g, self, axes=[dim], starts=[start], ends=[end])
|
||||
else:
|
||||
raise NotImplementedError("Unknown aten::slice signature")
|
||||
|
|
@ -2066,7 +2066,7 @@ def to(g, self, *args):
|
|||
|
||||
|
||||
def repeat(g, self, repeats):
|
||||
dtype = 4 # int64
|
||||
dtype = ScalarType.INT64
|
||||
shape_ = ones_like(g, repeats, dtype)
|
||||
self = g.op("Expand", self, shape_)
|
||||
return g.op("Tile", self, repeats)
|
||||
|
|
@ -2348,7 +2348,7 @@ def lstm_cell(g, self, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
hidden = [sym_help._unsqueeze_helper(g, x, [0]) for x in hidden]
|
||||
weight = (w_ih, w_hh, b_ih, b_hh) if sym_help._is_tensor(b_ih) else (w_ih, w_hh)
|
||||
has_biases = True if sym_help._is_tensor(b_ih) else False
|
||||
_, h_outs, c_outs = _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers=1,
|
||||
_, h_outs, c_outs = _generic_rnn(g, "LSTM", input, hidden, weight, has_biases, num_layers=1,
|
||||
dropout=0, train=0, bidirectional=False, batch_first=False)
|
||||
return sym_help._squeeze_helper(g, h_outs, [0]), sym_help._squeeze_helper(g, c_outs, [0])
|
||||
|
||||
|
|
@ -2434,7 +2434,7 @@ def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total
|
|||
def randn(g, shapes, dtype, *options):
|
||||
dtype = sym_help._get_const(dtype, "i", "dtype")
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
shape = sym_help._maybe_get_const(shapes, "is")
|
||||
if sym_help._is_value(shape):
|
||||
shape_const = g.op("ConstantOfShape", shapes,
|
||||
|
|
@ -2446,7 +2446,7 @@ def randn(g, shapes, dtype, *options):
|
|||
def rand(g, shapes, dtype, *options):
|
||||
dtype = sym_help._get_const(dtype, "i", "dtype")
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
shape = sym_help._maybe_get_const(shapes, "is")
|
||||
if sym_help._is_value(shape):
|
||||
shape_const = g.op("ConstantOfShape", shapes,
|
||||
|
|
@ -2458,14 +2458,14 @@ def rand(g, shapes, dtype, *options):
|
|||
def randn_like(g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
dtype = sym_help._get_const(dtype, "i", "dtype")
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
return g.op("RandomNormalLike", self, dtype_i=sym_help.scalar_type_to_onnx[dtype])
|
||||
|
||||
|
||||
def rand_like(g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
dtype = sym_help._get_const(dtype, "i", "dtype")
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
return g.op("RandomUniformLike", self, dtype_i=sym_help.scalar_type_to_onnx[dtype])
|
||||
|
||||
|
||||
|
|
@ -2484,12 +2484,12 @@ def bernoulli(g, input, generator=None, out=None):
|
|||
dtype = sym_help._try_get_scalar_type(input)
|
||||
if dtype is None:
|
||||
return _unimplemented("Bernoulli", "input dtype not accessible")
|
||||
p = g.op('RandomUniformLike', input, high_f=1.0, low_f=0.0, dtype_i=sym_help.cast_pytorch_to_onnx[dtype])
|
||||
output = g.op('Less', p, input)
|
||||
p = g.op("RandomUniformLike", input, high_f=1.0, low_f=0.0, dtype_i=sym_help.cast_pytorch_to_onnx[dtype])
|
||||
output = g.op("Less", p, input)
|
||||
return g.op("Cast", output, to_i=sym_help.cast_pytorch_to_onnx[dtype])
|
||||
|
||||
|
||||
@parse_args('v')
|
||||
@parse_args("v")
|
||||
def log_sigmoid(g, input):
|
||||
p = g.op("Sigmoid", input)
|
||||
return g.op("Log", p)
|
||||
|
|
@ -2965,13 +2965,13 @@ def baddbmm(g, self, batch1, batch2, beta, alpha):
|
|||
return add(g, mul_a, mul_b)
|
||||
|
||||
|
||||
@parse_args('v', 's')
|
||||
@parse_args("v", "s")
|
||||
def meshgrid(g, tensor_list, indexing: Optional[str] = None):
|
||||
if indexing is None:
|
||||
indexing = 'ij'
|
||||
elif indexing not in {'ij', 'xy'}:
|
||||
raise ValueError(f'Unsupported indexing: {indexing}')
|
||||
if indexing == 'xy':
|
||||
indexing = "ij"
|
||||
elif indexing not in {"ij", "xy"}:
|
||||
raise ValueError(f"Unsupported indexing: {indexing}")
|
||||
if indexing == "xy":
|
||||
tensor_list[0], tensor_list[1] = tensor_list[1], tensor_list[0]
|
||||
tensors = [sym_help._reshape_helper(g, t, g.op("Constant", value_t=torch.LongTensor([-1])))
|
||||
for t in sym_help._unpack_list(tensor_list)]
|
||||
|
|
@ -2983,7 +2983,7 @@ def meshgrid(g, tensor_list, indexing: Optional[str] = None):
|
|||
shape_i[i] = tensors_shape[i]
|
||||
t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0))
|
||||
out.append(g.op("Expand", t_reshaped, out_shape))
|
||||
if indexing == 'xy':
|
||||
if indexing == "xy":
|
||||
out[0], out[1] = out[1], out[0]
|
||||
return g.op("prim::ListConstruct", *out)
|
||||
|
||||
|
|
@ -3217,11 +3217,11 @@ def dot(g, self, other):
|
|||
return matmul(g, self, other)
|
||||
|
||||
|
||||
@parse_args('v', 'v')
|
||||
@parse_args("v", "v")
|
||||
def fill(g, self, value):
|
||||
dtype = self.type().scalarType()
|
||||
if dtype is None:
|
||||
dtype = 6 # float
|
||||
dtype = ScalarType.FLOAT
|
||||
else:
|
||||
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
|
||||
|
||||
|
|
@ -3284,7 +3284,7 @@ def index_add(g, self, dim, index, other):
|
|||
return scatter_add(g, self, dim, expand_as(g, index, other), other)
|
||||
|
||||
|
||||
@parse_args('v', 'is', 'is')
|
||||
@parse_args("v", "is", "is")
|
||||
def roll(g, self, shifts, dims):
|
||||
assert len(shifts) == len(dims)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user