ONNX Export ConstantOfShape with default dtype (#27577)
Summary:
Exporting a scripted module to ONNX with ops like torch.zeros() fails when the dtype is not specified. This PR adds support for exporting scripted torch.zeros() (and similar ops) without an explicit dtype; the dtype now defaults to float.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/27577

Reviewed By: hl475

Differential Revision: D17822318

Pulled By: houseroad

fbshipit-source-id: b2d4300b869e782a9b72534fea1263eb83744953
This commit is contained in:
parent e049e0b027
commit 2093fac4ee
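To make the failure mode concrete, here is a minimal sketch of the pattern the summary describes: a scripted module calling torch.zeros() with no dtype, exported to ONNX. The module name and shapes are illustrative, not taken from the PR, and depending on the PyTorch version exporting a ScriptModule may also require passing example_outputs.

import io
import torch

class ZerosNoDtype(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # dtype omitted: before this change the exporter could not
        # parse it; with it, the output dtype defaults to float.
        return torch.zeros(x.size(0))

buf = io.BytesIO()
torch.onnx.export(ZerosNoDtype(), (torch.randn(3, 4),), buf, opset_version=9)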
test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -175,6 +175,16 @@ class TestONNXRuntime(unittest.TestCase):
         x = {"test_key_in": torch.randn(1, 2, 3)}
         self.run_test(MyModel(), (x,))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_cste_script(self):
+        class MyModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)
+
+        x = torch.randn(3, 4)
+        self.run_test(MyModel(), x)
+
     def test_clamp(self):
         class ClampModel(torch.nn.Module):
             def forward(self, x):
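run_test in this suite exports the module and checks PyTorch's outputs against ONNX Runtime's. A self-contained sketch of that comparison, assuming onnxruntime is installed (the helper name and tolerances here are mine, not the suite's):

import io

import numpy as np
import onnxruntime
import torch

def check_against_onnxruntime(model, x, opset_version=9):
    # Export to an in-memory buffer, run the same input through
    # ONNX Runtime, and compare against PyTorch's own outputs.
    buf = io.BytesIO()
    torch.onnx.export(model, (x,), buf, opset_version=opset_version)
    sess = onnxruntime.InferenceSession(buf.getvalue(),
                                        providers=["CPUExecutionProvider"])
    feed = {sess.get_inputs()[0].name: x.numpy()}
    ort_outs = sess.run(None, feed)
    torch_outs = model(x)
    if isinstance(torch_outs, torch.Tensor):
        torch_outs = (torch_outs,)
    for expected, actual in zip(torch_outs, ort_outs):
        np.testing.assert_allclose(expected.detach().numpy(), actual,
                                   rtol=1e-3, atol=1e-7)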
torch/onnx/symbolic_helper.py
@@ -54,6 +54,8 @@ def _parse_arg(value, desc):
         return value
     if desc == 'v' or not _is_value(value):
         return value
+    if value.node().mustBeNone():
+        return None
     if value.node().kind() == 'onnx::Constant':
         tval = value.node()['value']
         if desc == 'i':
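Why mustBeNone() is needed: in a scripted graph, an omitted optional argument such as dtype reaches the symbolic as a prim::Constant node of None type, which the old _parse_arg could not convert to an integer. A quick, illustrative way to see those None constants:

import torch

@torch.jit.script
def make_zeros(x):
    # dtype, layout and device are omitted, so scripting wires
    # None-valued prim::Constant nodes into the aten::zeros call.
    return torch.zeros(x.size(0))

print(make_zeros.graph)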
torch/onnx/symbolic_opset9.py
@@ -205,6 +205,8 @@ def flatten(g, input, start_dim, end_dim):
 
 
 def _constant_fill(g, sizes, dtype, const_value):
+    if dtype is None:
+        dtype = 6  # float
     if not sym_help.scalar_type_to_pytorch_type[dtype].is_floating_point:
         result = g.op(
             "ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"], input_as_shape_i=1, value_f=const_value)
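The magic number 6 indexes into sym_help.scalar_type_to_pytorch_type. A sketch of that mapping as it stood around this change, listed here for reference rather than copied verbatim from the PR:

import torch

# Index convention used by the symbolics above: position in this
# list is the ScalarType code, so `dtype = 6` selects torch.float.
scalar_type_to_pytorch_type = [
    torch.uint8,   # 0
    torch.int8,    # 1
    torch.short,   # 2
    torch.int,     # 3
    torch.int64,   # 4
    torch.half,    # 5
    torch.float,   # 6
    torch.double,  # 7
]
assert scalar_type_to_pytorch_type[6] is torch.float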
@@ -1215,6 +1215,8 @@ def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=
 @parse_args('v', 'i', 'v', 'v', 'v')
 def zeros(g, sizes, dtype, layout, device, pin_memory=False):
     # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
@@ -1222,12 +1224,16 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'i', 'v', 'v', 'v')
 def zeros_like(g, input, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v', 'v')
 def ones(g, sizes, dtype, layout, device, pin_memory=False):
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
@@ -1235,11 +1241,15 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'i', 'v', 'v', 'v')
 def ones_like(g, input, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 def full(g, sizes, value, dtype, layout, device, pin_memory=False):
+    if dtype is None:
+        dtype = 6  # float
     const_value = sym_help._maybe_get_const(value, 't')
     if sym_help._is_value(const_value):
         tmp = zeros(g, sizes, dtype, layout, device)
@@ -1253,6 +1263,8 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
 @parse_args('v', 'f', 'i', 'v', 'v', 'v')
 def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
     shape = g.op("Shape", input)
+    if dtype is None:
+        dtype = 6  # float
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([fill_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
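For reference, a standalone sketch of the ONNX node these symbolics emit, built with the onnx package (tensor names are illustrative). ConstantOfShape takes a runtime shape tensor as input plus a one-element value attribute that fixes both the fill value and the output dtype; per the ONNX spec, omitting the attribute fills with float zeros, which is why float is the natural default here.

import numpy as np
from onnx import helper, numpy_helper

# ConstantOfShape: input "shape" is a 1-D int64 tensor computed at
# runtime; the one-element `value` attribute sets fill value + dtype.
node = helper.make_node(
    "ConstantOfShape",
    inputs=["shape"],
    outputs=["filled"],
    value=numpy_helper.from_array(np.array([1.0], dtype=np.float32)),
)
print(node)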