ONNX Export ConstantOfShape with default dtype (#27577)

Summary:
Exporting a scripted module to ONNX, with ops like torch.zeros(), fails when the dtype is not specified.
This PR adds support to exporting scripted torch.zeros() ops (and similar ops) without specifying the dtype (dtype will default to float).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27577

Reviewed By: hl475

Differential Revision: D17822318

Pulled By: houseroad

fbshipit-source-id: b2d4300b869e782a9b72534fea1263eb83744953
This commit is contained in:
Lara Haidar 2019-10-09 17:03:27 -07:00 committed by Facebook Github Bot
parent e049e0b027
commit 2093fac4ee
4 changed files with 26 additions and 0 deletions

View File

@ -175,6 +175,16 @@ class TestONNXRuntime(unittest.TestCase):
x = {"test_key_in": torch.randn(1, 2, 3)}
self.run_test(MyModel(), (x,))
@skipIfUnsupportedMinOpsetVersion(9)
def test_cste_script(self):
class MyModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)
x = torch.randn(3, 4)
self.run_test(MyModel(), x)
def test_clamp(self):
class ClampModel(torch.nn.Module):
def forward(self, x):

View File

@ -54,6 +54,8 @@ def _parse_arg(value, desc):
return value
if desc == 'v' or not _is_value(value):
return value
if value.node().mustBeNone():
return None
if value.node().kind() == 'onnx::Constant':
tval = value.node()['value']
if desc == 'i':

View File

@ -205,6 +205,8 @@ def flatten(g, input, start_dim, end_dim):
def _constant_fill(g, sizes, dtype, const_value):
if dtype is None:
dtype = 6 # float
if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
result = g.op(
"ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)

View File

@ -1215,6 +1215,8 @@ def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=
@parse_args('v', 'i', 'v', 'v', 'v')
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
# NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
if dtype is None:
dtype = 6 # float
return g.op("ConstantOfShape", sizes,
value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@ -1222,12 +1224,16 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
@parse_args('v', 'i', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
return g.op("ConstantOfShape", shape,
value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v', 'v')
def ones(g, sizes, dtype, layout, device, pin_memory=False):
if dtype is None:
dtype = 6 # float
return g.op("ConstantOfShape", sizes,
value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@ -1235,11 +1241,15 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
@parse_args('v', 'i', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
return g.op("ConstantOfShape", shape,
value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
if dtype is None:
dtype = 6 # float
const_value = sym_help._maybe_get_const(value, 't')
if sym_help._is_value(const_value):
tmp = zeros(g, sizes, dtype, layout, device)
@ -1253,6 +1263,8 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
@parse_args('v', 'f', 'i', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
return g.op("ConstantOfShape", shape,
value_t=torch.tensor([fill_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))