diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3c741fb8845..4e39baf3227 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -175,6 +175,16 @@ class TestONNXRuntime(unittest.TestCase):
         x = {"test_key_in": torch.randn(1, 2, 3)}
         self.run_test(MyModel(), (x,))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_cste_script(self):
+        class MyModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)
+
+        x = torch.randn(3, 4)
+        self.run_test(MyModel(), x)
+
     def test_clamp(self):
         class ClampModel(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 048e69c36c1..762e0a393ac 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -54,6 +54,8 @@ def _parse_arg(value, desc):
         return value
     if desc == 'v' or not _is_value(value):
         return value
+    if value.node().mustBeNone():
+        return None
     if value.node().kind() == 'onnx::Constant':
         tval = value.node()['value']
         if desc == 'i':
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 52a48d2fa58..4da2bc1ead5 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -205,6 +205,8 @@ def flatten(g, input, start_dim, end_dim):
 
 
 def _constant_fill(g, sizes, dtype, const_value):
+    if dtype is None:
+        dtype = 6  # float
     if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
         result = g.op(
             "ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 19dd546ecc8..2ceee200806 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1215,6 +1215,8 @@ def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
 @parse_args('v', 'i', 'v', 'v', 'v')
 def zeros(g, sizes, dtype, layout, device, pin_memory=False):
     # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
@@ -1222,12 +1224,16 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'i', 'v', 'v', 'v')
 def zeros_like(g, input, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v', 'v')
 def ones(g, sizes, dtype, layout, device, pin_memory=False):
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
@@ -1235,11 +1241,15 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'i', 'v', 'v', 'v')
 def ones_like(g, input, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 def full(g, sizes, value, dtype, layout, device, pin_memory=False):
+    if dtype is None:
+        dtype = 6  # float
     const_value = sym_help._maybe_get_const(value, 't')
     if sym_help._is_value(const_value):
         tmp = zeros(g, sizes, dtype, layout, device)
@@ -1253,6 +1263,8 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'f', 'i', 'v', 'v', 'v')
 def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([fill_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
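For reference (not part of the patch): a minimal export sketch that exercises the new default-dtype path, mirroring test_cste_script above. When torch.zeros/torch.ones is called without an explicit dtype, _parse_arg now returns None for the dtype argument (via mustBeNone) and the opset symbolics fall back to 6, the scalar-type index for float, instead of failing. torch.onnx.export and its opset_version argument are real; the example_outputs keyword is an assumption about what script-module export requires on the releases this targets.

import io

import torch


class MyModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # torch.zeros is called without an explicit dtype, so the zeros()
        # symbolic receives dtype=None and defaults to float; torch.ones
        # gets an explicit int64 dtype and keeps it.
        return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)


model = MyModel()
x = torch.randn(3, 4)
f = io.BytesIO()
# example_outputs is assumed to be needed for script-module export here.
torch.onnx.export(model, (x,), f, opset_version=9, example_outputs=model(x))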