From 2093fac4ee3dabd24122d726a7e05e5bbab5b04c Mon Sep 17 00:00:00 2001
From: Lara Haidar
Date: Wed, 9 Oct 2019 17:03:27 -0700
Subject: [PATCH] ONNX Export ConstantOfShape with default dtype (#27577)

Summary:
Exporting a scripted module to ONNX with ops like torch.zeros() fails when the dtype is not specified. This PR adds support for exporting scripted torch.zeros() (and similar ops) without specifying the dtype; the dtype then defaults to float.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/27577

Reviewed By: hl475

Differential Revision: D17822318

Pulled By: houseroad

fbshipit-source-id: b2d4300b869e782a9b72534fea1263eb83744953
---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 10 ++++++++++
 torch/onnx/symbolic_helper.py              |  2 ++
 torch/onnx/symbolic_opset8.py              |  2 ++
 torch/onnx/symbolic_opset9.py              | 12 ++++++++++++
 4 files changed, 26 insertions(+)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3c741fb8845..4e39baf3227 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -175,6 +175,16 @@ class TestONNXRuntime(unittest.TestCase):
         x = {"test_key_in": torch.randn(1, 2, 3)}
         self.run_test(MyModel(), (x,))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_cste_script(self):
+        class MyModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)
+
+        x = torch.randn(3, 4)
+        self.run_test(MyModel(), x)
+
     def test_clamp(self):
         class ClampModel(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 048e69c36c1..762e0a393ac 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -54,6 +54,8 @@ def _parse_arg(value, desc):
         return value
     if desc == 'v' or not _is_value(value):
         return value
+    if value.node().mustBeNone():
+        return None
     if value.node().kind() == 'onnx::Constant':
         tval = value.node()['value']
         if desc == 'i':
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 52a48d2fa58..4da2bc1ead5 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -205,6 +205,8 @@ def flatten(g, input, start_dim, end_dim):
 
 
 def _constant_fill(g, sizes, dtype, const_value):
+    if dtype is None:
+        dtype = 6  # float
     if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
         result = g.op(
             "ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 19dd546ecc8..2ceee200806 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1215,6 +1215,8 @@ def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=
 @parse_args('v', 'i', 'v', 'v', 'v')
 def zeros(g, sizes, dtype, layout, device, pin_memory=False):
     # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
@@ -1222,12 +1224,16 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'i', 'v', 'v', 'v')
 def zeros_like(g, input, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape, value_t=torch.tensor([0],
                 dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v', 'v')
 def ones(g, sizes, dtype, layout, device, pin_memory=False):
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
@@ -1235,11 +1241,15 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'i', 'v', 'v', 'v')
 def ones_like(g, input, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape, value_t=torch.tensor([1],
                 dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 def full(g, sizes, value, dtype, layout, device, pin_memory=False):
+    if dtype is None:
+        dtype = 6  # float
     const_value = sym_help._maybe_get_const(value, 't')
     if sym_help._is_value(const_value):
         tmp = zeros(g, sizes, dtype, layout, device)
@@ -1253,6 +1263,8 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'f', 'i', 'v', 'v', 'v')
 def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape, value_t=torch.tensor([fill_value],
                 dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
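
For context, here is a minimal sketch (not part of the patch) of the export scenario this change fixes. It assumes the PyTorch 1.3-era torch.onnx.export API, where exporting a ScriptModule also required passing example_outputs; the module and variable names are illustrative only:

```python
import io

import torch


class ZerosNoDtype(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # dtype is omitted on purpose: before this patch the ONNX exporter
        # failed on the resulting None dtype; with it, float is assumed.
        return torch.zeros(x.size(0))


model = ZerosNoDtype()
x = torch.randn(3, 4)
f = io.BytesIO()
# example_outputs was needed for ScriptModule export in contemporaneous
# PyTorch releases (it was removed from the API later).
torch.onnx.export(model, x, f, opset_version=9, example_outputs=model(x))
```

The mechanism: the new mustBeNone() check in _parse_arg lets an unspecified dtype reach the symbolic functions as None instead of raising, and each symbolic then falls back to scalar type 6, which maps to float in sym_help.scalar_type_to_pytorch_type.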