Fix torch.arange traced as constant (#25363)

Summary:
torch.arange is always traced as a constant, which makes it impossible to correctly trace TestModel() from the example below.

import torch

class TestModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])

input = torch.randn(5, 3, 2)
print(torch.jit.trace(TestModel(), input).graph)

Currently the trace of TestModel() looks like:

graph(%self : ClassType<TestModel>,
      %input : Float(5, 3, 2)):
  %11 : int = prim::Constant[value=5]()
  %12 : int = prim::Constant[value=4]()
  %13 : int = prim::Constant[value=0]()
  %14 : Device = prim::Constant[value="cpu"]()
  %15 : bool = prim::Constant[value=0]()
  %16 : Long(5) = aten::arange(%11, %12, %13, %14, %15)
  return (%16)

This PR allows the trace to record a variable value for %11: the input's first dimension is read with aten::size and fed through prim::NumToTensor / prim::ImplicitTensorToNum into aten::arange, instead of being baked in as a constant. With this PR's changes, the trace of TestModel() looks like:

graph(%self : ClassType<TestModel>,
      %input : Float(5, 3, 2)):
  %2 : int = prim::Constant[value=0]()
  %3 : int = aten::size(%input, %2)
  %4 : Long() = prim::NumToTensor(%3)
  %11 : Scalar = prim::ImplicitTensorToNum(%4)
  %12 : int = prim::Constant[value=4]()
  %13 : int = prim::Constant[value=0]()
  %14 : Device = prim::Constant[value="cpu"]()
  %15 : bool = prim::Constant[value=0]()
  %16 : Long(5) = aten::arange(%11, %12, %13, %14, %15)
  return (%16)

More info: https://github.com/pytorch/pytorch/issues/20075
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25363

Reviewed By: zrphercule

Differential Revision: D17301934

Pulled By: houseroad

fbshipit-source-id: d9907763742cb51d8c761bf63fc2e4918f7b9941
Author: Lara Haidar
Date: 2019-09-11 13:37:20 -07:00
Committed by: Facebook Github Bot
Parent: 62767077c3
Commit: 8ca93ec351

5 changed files with 229 additions and 30 deletions

test/onnx/expect/TestOperators.test_dyn_arange.expect

@@ -0,0 +1,129 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.2"
+graph {
+  node {
+    output: "1"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: 7
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "0"
+    output: "2"
+    op_type: "Shape"
+  }
+  node {
+    input: "2"
+    input: "1"
+    output: "3"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "3"
+    output: "4"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "4"
+    output: "5"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        dims: 1
+        data_type: 7
+        raw_data: "\001\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "5"
+    output: "6"
+    op_type: "NonZero"
+  }
+  node {
+    input: "6"
+    output: "7"
+    op_type: "Transpose"
+    attribute {
+      name: "perm"
+      ints: 1
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "7"
+    output: "8"
+    op_type: "Squeeze"
+    attribute {
+      name: "axes"
+      ints: 1
+      type: INTS
+    }
+  }
+  node {
+    input: "8"
+    output: "9"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 7
+      type: INT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "9"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 5
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}

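Note on the expect file above: opset 9 has no dynamic Range operator, so the exporter emulates arange(n) by building a length-n vector of ones (ConstantOfShape) and taking the indices of its nonzero entries (NonZero, then Transpose/Squeeze/Cast to recover a 1-D int64 vector). A minimal NumPy sketch of the same trick (illustrative only, not part of this PR):

import numpy as np

def onnx9_style_arange(n):
    # ConstantOfShape: a length-n vector filled with int64 ones
    ones = np.ones(n, dtype=np.int64)
    # NonZero: np.nonzero returns a tuple of index arrays, one per
    # dimension; for a 1-D input that is a single array [0, 1, ..., n-1]
    idx = np.nonzero(ones)[0]
    # Cast: back to int64, matching the Cast(to=7) node in the graph
    return idx.astype(np.int64)

assert (onnx9_style_arange(5) == np.arange(5)).all()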
test/onnx/test_operators.py

@@ -713,6 +713,14 @@ class TestOperators(TestCase):
         inputs = (scores, bbox_deltas, im_info, anchors)
         self.assertONNX(model, inputs)
 
+    def test_dyn_arange(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.arange(input.shape[0])
+
+        input = torch.randn(5, 3, 2)
+        self.assertONNX(TestModel(), input)
+
     def test_layer_norm_aten(self):
         model = torch.nn.LayerNorm([10, 10])
         x = torch.randn(20, 5, 10, 10)

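The new test drives the expect file above through the assertONNX harness. The same graph can be reproduced without the harness via the public export API; a small sketch (assumes the onnx package is installed):

import io

import onnx
import torch

class TestModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])

f = io.BytesIO()
torch.onnx.export(TestModel(), torch.randn(5, 3, 2), f, opset_version=9)
model = onnx.load_from_string(f.getvalue())
# expected node kinds, per the expect file: Constant, Shape, Gather,
# Unsqueeze, ConstantOfShape, NonZero, Transpose, Squeeze, Cast
print([node.op_type for node in model.graph.node])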
test/onnx/test_pytorch_onnx_onnxruntime.py

@@ -19,9 +19,33 @@ from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE
 import model_defs.word_language_model as word_language_model
 
+
+def ort_test_with_input(ort_sess, input, output, rtol, atol):
+    input, _ = torch.jit._flatten(input)
+    output, _ = torch.jit._flatten(output)
+
+    def to_numpy(tensor):
+        if tensor.requires_grad:
+            return tensor.detach().cpu().numpy()
+        else:
+            return tensor.cpu().numpy()
+
+    inputs = list(map(to_numpy, input))
+    outputs = list(map(to_numpy, output))
+
+    ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
+    ort_outs = ort_sess.run(None, ort_inputs)
+
+    # compare onnxruntime and PyTorch results
+    assert len(outputs) == len(ort_outs), "number of outputs differ"
+
+    # compare onnxruntime and PyTorch results
+    [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
+
+
 def run_model_test(self, model, batch_size=2, state_dict=None,
                    input=None, use_gpu=True, rtol=0.001, atol=1e-7,
-                   example_outputs=None, do_constant_folding=True):
+                   example_outputs=None, do_constant_folding=True,
+                   dynamic_axes=None, test_with_inputs=None):
     model.eval()
 
     if input is None:
@@ -40,31 +64,25 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
                            opset_version=self.opset_version,
                            example_outputs=output,
                            do_constant_folding=do_constant_folding,
-                           keep_initializers_as_inputs=self.keep_initializers_as_inputs)
-    input, _ = torch.jit._flatten(input)
-    output, _ = torch.jit._flatten(output)
-
-    def to_numpy(tensor):
-        if tensor.requires_grad:
-            return tensor.detach().cpu().numpy()
-        else:
-            return tensor.cpu().numpy()
-
-    inputs = list(map(to_numpy, input))
-    outputs = list(map(to_numpy, output))
+                           keep_initializers_as_inputs=self.keep_initializers_as_inputs,
+                           dynamic_axes=dynamic_axes)
 
     # compute onnxruntime output prediction
     ort_sess = onnxruntime.InferenceSession(f.getvalue())
-    ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
-    ort_outs = ort_sess.run(None, ort_inputs)
-
-    # compare onnxruntime and PyTorch results
-    assert len(outputs) == len(ort_outs), "number of outputs differ"
-
-    # compare onnxruntime and PyTorch results
-    [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
+    ort_test_with_input(ort_sess, input, output, rtol, atol)
+
+    # if additional test inputs are provided run the onnx
+    # model with these inputs and check the outputs
+    if test_with_inputs is not None:
+        for test_input in test_with_inputs:
+            if isinstance(test_input, torch.Tensor):
+                test_input = (test_input,)
+            output = model(*test_input)
+            if isinstance(output, torch.Tensor):
+                output = (output,)
+            ort_test_with_input(ort_sess, test_input, output, rtol, atol)
 
 
 class TestONNXRuntime(unittest.TestCase):
@@ -78,10 +96,12 @@ class TestONNXRuntime(unittest.TestCase):
         torch.cuda.manual_seed_all(0)
         np.random.seed(seed=0)
 
-    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True, batch_size=2, use_gpu=True):
+    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
+                 batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None):
         run_model_test(self, model, batch_size=batch_size,
                        input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
-                       do_constant_folding=do_constant_folding)
+                       do_constant_folding=do_constant_folding,
+                       dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs)
 
     def run_word_language_model(self, model_name):
         ntokens = 50
@@ -280,6 +300,21 @@
         x = torch.rand(5, 5, 5)
         self.run_test(DynamicSliceExportMod(), x)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange(self):
+        class ArangeModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.arange(input.shape[0]), \
+                    torch.arange(12), \
+                    torch.arange(start=input.shape[0], end=input.shape[0] + 5)
+
+        x = torch.randn(5, 3, 2)
+        y = torch.randn(8, 3, 2)
+        self.run_test(ArangeModel(), x, test_with_inputs=[y],
+                      dynamic_axes={'input_1': [0],
+                                    'output_1': [0],
+                                    'output_2': [0]})
+
     def _test_index_generic(self, fn):
         class MyModel(torch.nn.Module):
             def __init__(self):

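For reference, the new test_with_inputs/dynamic_axes path corresponds roughly to the following standalone flow, where the second input has a different batch dimension than the traced one (a sketch, assuming onnxruntime is installed; the names input_1/output_1 are illustrative):

import io

import numpy as np
import onnxruntime
import torch

class ArangeModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])

f = io.BytesIO()
torch.onnx.export(ArangeModel(), torch.randn(5, 3, 2), f, opset_version=9,
                  input_names=['input_1'], output_names=['output_1'],
                  dynamic_axes={'input_1': [0], 'output_1': [0]})

# run the exported model on a batch size the trace never saw
sess = onnxruntime.InferenceSession(f.getvalue())
(out,) = sess.run(None, {'input_1': np.random.randn(8, 3, 2).astype(np.float32)})
np.testing.assert_array_equal(out, np.arange(8))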
test/test_jit.py

@@ -1553,6 +1553,40 @@ graph(%Ra, %Rb):
     def test_trace_size_with_grad(self):
         self.do_trace_size(True)
 
+    def do_trace_arange(self, requires_grad):
+        def arange(x):
+            return torch.arange(x.shape[0])
+
+        def arange_scalar(x):
+            return torch.arange(12)
+
+        def arange_start_end(x):
+            return torch.arange(start=x.shape[0], end=x.shape[0] + 5)
+
+        x = torch.randn(5, 3, 2, requires_grad=requires_grad)
+        y = torch.randn(8, 2, 4, requires_grad=requires_grad)
+
+        # Check that it behaves as expected
+        traced_arange = torch.jit.trace(arange, x)
+        self.assertEqual(traced_arange(y), arange(y))
+        self.assertEqual(traced_arange(x), arange(x))
+
+        traced_arange_scalar = torch.jit.trace(arange_scalar, x)
+        self.assertEqual(traced_arange_scalar(y), arange_scalar(y))
+        self.assertEqual(traced_arange_scalar(x), arange_scalar(x))
+
+        traced_arange_start_end = torch.jit.trace(arange_start_end, x)
+        self.assertEqual(traced_arange_start_end(y), arange_start_end(y))
+        self.assertEqual(traced_arange_start_end(x), arange_start_end(x))
+
+    def test_trace_arange(self):
+        self.do_trace_arange(False)
+
+    # test the different graph_executor path that happens when
+    # gradients are required and sizes are involved
+    def test_trace_arange_with_grad(self):
+        self.do_trace_arange(True)
+
     def test_trace_casts(self):
         casts = [
             lambda x: x.byte(),
@@ -12018,13 +12052,6 @@ a")
                 FooMod(), (torch.rand(3, 4),), f,
                 operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
 
-    def test_trace_checker_arange_as_constant(self):
-        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
-            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
-            def foo(x):
-                y = torch.arange(0, x.shape[0]).double()
-                return x + y.unsqueeze(1)
-
     @suppress_warnings
     def test_trace_checker_dot_data(self):
         with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '

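The deleted test_trace_checker_arange_as_constant pinned down the old, broken behavior: the trace checker flagged arange graphs as differing across input sizes. With this PR the same function traces cleanly and generalizes, so the test no longer applies. A quick sketch of the new behavior (illustrative, run against a build that includes this PR):

import torch

def foo(x):
    # before this PR, x.shape[0] was baked into the trace as a constant and
    # the 4x5 check input made the checker raise TracingCheckError
    y = torch.arange(0, x.shape[0]).double()
    return x + y.unsqueeze(1)

traced = torch.jit.trace(foo, torch.rand(3, 4),
                         check_inputs=[(torch.rand(4, 5),)])
print(traced(torch.rand(6, 2)).shape)  # torch.Size([6, 2]); the trace generalizes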
tools/autograd/templates/python_torch_functions.cpp

@@ -112,7 +112,7 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs)
   static PythonArgParser parser({
     "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
     "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
-  });
+  }, /*traceable=*/true);
 
   ParsedArgs<9> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
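
This one-line change is the core of the fix: constructing the hand-written arange binding's PythonArgParser with traceable=true lets the tracer record the call with its actual Scalar inputs, including ones derived from traced tensor sizes, instead of constant-folding the whole op. A quick way to check the effect from Python (a sketch; assumes a build that includes this PR):

import torch

def f(x):
    return torch.arange(x.shape[0])

graph = torch.jit.trace(f, torch.randn(5, 3, 2)).graph
kinds = [node.kind() for node in graph.nodes()]
# the end argument now flows from the input's size rather than a constant
assert 'aten::size' in kinds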