Fix torch.arange traced as constant (#25363)
Summary:
torch.arange is always traced as a constant, which makes it impossible to correctly trace the TestModel() from the example below.

class TestModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])

input = torch.randn(5, 3, 2)
print(torch.jit.trace(TestModel(), input).graph)
Currently the trace of TestModel() looks like:

graph(%self : ClassType<TestModel>,
      %input : Float(5, 3, 2)):
  %11 : int = prim::Constant[value=5]()
  %12 : int = prim::Constant[value=4]()
  %13 : int = prim::Constant[value=0]()
  %14 : Device = prim::Constant[value="cpu"]()
  %15 : bool = prim::Constant[value=0]()
  %16 : Long(5) = aten::arange(%11, %12, %13, %14, %15)
  return (%16)
This PR allows the trace to carry a variable value for %11. With this PR's modifications, the trace of TestModel() looks like:

graph(%self : ClassType<TestModel>,
      %input : Float(5, 3, 2)):
  %2 : int = prim::Constant[value=0]()
  %3 : int = aten::size(%input, %2)
  %4 : Long() = prim::NumToTensor(%3)
  %11 : Scalar = prim::ImplicitTensorToNum(%4)
  %12 : int = prim::Constant[value=4]()
  %13 : int = prim::Constant[value=0]()
  %14 : Device = prim::Constant[value="cpu"]()
  %15 : bool = prim::Constant[value=0]()
  %16 : Long(5) = aten::arange(%11, %12, %13, %14, %15)
  return (%16)
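
A quick way to see the fix in action (illustrative snippet, not part of the PR): re-run the traced module on an input with a different leading dimension and check that the output follows it.

import torch

class TestModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])

traced = torch.jit.trace(TestModel(), torch.randn(5, 3, 2))
# Before the fix the size 5 was baked into the graph as a constant, so
# this call still returned tensor([0, 1, 2, 3, 4]); with the fix the
# result follows the new input's first dimension.
print(traced(torch.randn(8, 3, 2)))  # expected: tensor([0, 1, ..., 7])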
More info: https://github.com/pytorch/pytorch/issues/20075
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25363
Reviewed By: zrphercule
Differential Revision: D17301934
Pulled By: houseroad
fbshipit-source-id: d9907763742cb51d8c761bf63fc2e4918f7b9941
This commit is contained in:
parent 62767077c3
commit 8ca93ec351
test/onnx/expect/TestOperators.test_dyn_arange.expect (new file, 129 lines)
@@ -0,0 +1,129 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.2"
+graph {
+  node {
+    output: "1"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: 7
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "0"
+    output: "2"
+    op_type: "Shape"
+  }
+  node {
+    input: "2"
+    input: "1"
+    output: "3"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "3"
+    output: "4"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "4"
+    output: "5"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        dims: 1
+        data_type: 7
+        raw_data: "\001\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "5"
+    output: "6"
+    op_type: "NonZero"
+  }
+  node {
+    input: "6"
+    output: "7"
+    op_type: "Transpose"
+    attribute {
+      name: "perm"
+      ints: 1
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "7"
+    output: "8"
+    op_type: "Squeeze"
+    attribute {
+      name: "axes"
+      ints: 1
+      type: INTS
+    }
+  }
+  node {
+    input: "8"
+    output: "9"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 7
+      type: INT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "9"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 5
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
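
The expect file above records how the dynamic arange is decomposed for ONNX opset 9, which has no Range operator: Shape plus Gather extract the relevant input dimension, Unsqueeze plus ConstantOfShape materialize a ones tensor of that length, and NonZero (followed by Transpose, Squeeze, and a Cast to int64) yields the indices 0..n-1. A numpy sketch of the same computation (illustrative, not part of the PR; onnx_style_arange is a made-up name):

import numpy as np

def onnx_style_arange(x, axis=0):
    n = np.asarray(x.shape)[axis]                   # Shape -> Gather
    ones = np.ones(np.expand_dims(n, 0), np.int64)  # Unsqueeze -> ConstantOfShape
    idx = np.stack(np.nonzero(ones))                # NonZero, shape (1, n)
    # Transpose -> Squeeze -> Cast
    return np.transpose(idx, (1, 0)).squeeze(1).astype(np.int64)

print(onnx_style_arange(np.zeros((5, 3, 2))))  # [0 1 2 3 4]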
@@ -713,6 +713,14 @@ class TestOperators(TestCase):
         inputs = (scores, bbox_deltas, im_info, anchors)
         self.assertONNX(model, inputs)
 
+    def test_dyn_arange(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.arange(input.shape[0])
+
+        input = torch.randn(5, 3, 2)
+        self.assertONNX(TestModel(), input)
+
     def test_layer_norm_aten(self):
         model = torch.nn.LayerNorm([10, 10])
         x = torch.randn(20, 5, 10, 10)
@@ -19,9 +19,33 @@ from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_S
 import model_defs.word_language_model as word_language_model
 
 
+def ort_test_with_input(ort_sess, input, output, rtol, atol):
+    input, _ = torch.jit._flatten(input)
+    output, _ = torch.jit._flatten(output)
+
+    def to_numpy(tensor):
+        if tensor.requires_grad:
+            return tensor.detach().cpu().numpy()
+        else:
+            return tensor.cpu().numpy()
+
+    inputs = list(map(to_numpy, input))
+    outputs = list(map(to_numpy, output))
+
+    ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
+    ort_outs = ort_sess.run(None, ort_inputs)
+
+    # compare onnxruntime and PyTorch results
+    assert len(outputs) == len(ort_outs), "number of outputs differ"
+
+    # compare onnxruntime and PyTorch results
+    [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
+
+
 def run_model_test(self, model, batch_size=2, state_dict=None,
                    input=None, use_gpu=True, rtol=0.001, atol=1e-7,
-                   example_outputs=None, do_constant_folding=True):
+                   example_outputs=None, do_constant_folding=True,
+                   dynamic_axes=None, test_with_inputs=None):
     model.eval()
 
     if input is None:
@@ -40,31 +64,25 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
                            opset_version=self.opset_version,
                            example_outputs=output,
                            do_constant_folding=do_constant_folding,
-                           keep_initializers_as_inputs=self.keep_initializers_as_inputs)
-
-    input, _ = torch.jit._flatten(input)
-    output, _ = torch.jit._flatten(output)
-
-    def to_numpy(tensor):
-        if tensor.requires_grad:
-            return tensor.detach().cpu().numpy()
-        else:
-            return tensor.cpu().numpy()
-
-    inputs = list(map(to_numpy, input))
-    outputs = list(map(to_numpy, output))
+                           keep_initializers_as_inputs=self.keep_initializers_as_inputs,
+                           dynamic_axes=dynamic_axes)
 
     # compute onnxruntime output prediction
     ort_sess = onnxruntime.InferenceSession(f.getvalue())
-    ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
-    ort_outs = ort_sess.run(None, ort_inputs)
+    ort_test_with_input(ort_sess, input, output, rtol, atol)
 
-    # compare onnxruntime and PyTorch results
-    assert len(outputs) == len(ort_outs), "number of outputs differ"
+    # if additional test inputs are provided run the onnx
+    # model with these inputs and check the outputs
+    if test_with_inputs is not None:
+        for test_input in test_with_inputs:
+            if isinstance(test_input, torch.Tensor):
+                test_input = (test_input,)
+            output = model(*test_input)
+            if isinstance(output, torch.Tensor):
+                output = (output,)
 
-    # compare onnxruntime and PyTorch results
-    [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
+            ort_test_with_input(ort_sess, test_input, output, rtol, atol)
 
 
 class TestONNXRuntime(unittest.TestCase):
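
The two hunks above factor the onnxruntime comparison out of run_model_test into the reusable ort_test_with_input, so the same exported model can be checked against additional inputs. For reference, a self-contained sketch of the export-then-verify flow these helpers wrap (assumed setup, not from the PR):

import io

import numpy as np
import onnxruntime
import torch

model = torch.nn.Linear(3, 2).eval()
x = torch.randn(4, 3)

# Export to an in-memory buffer and load it into onnxruntime.
f = io.BytesIO()
torch.onnx.export(model, x, f)
ort_sess = onnxruntime.InferenceSession(f.getvalue())

# Compare onnxruntime's output against PyTorch's.
ort_outs = ort_sess.run(None, {ort_sess.get_inputs()[0].name: x.numpy()})
np.testing.assert_allclose(model(x).detach().numpy(), ort_outs[0],
                           rtol=1e-3, atol=1e-7)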
@@ -78,10 +96,12 @@ class TestONNXRuntime(unittest.TestCase):
         torch.cuda.manual_seed_all(0)
         np.random.seed(seed=0)
 
-    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True, batch_size=2, use_gpu=True):
+    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
+                 batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None):
         run_model_test(self, model, batch_size=batch_size,
                        input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
-                       do_constant_folding=do_constant_folding)
+                       do_constant_folding=do_constant_folding,
+                       dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs)
 
     def run_word_language_model(self, model_name):
         ntokens = 50
@@ -280,6 +300,21 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.rand(5, 5, 5)
         self.run_test(DynamicSliceExportMod(), x)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange(self):
+        class ArangeModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.arange(input.shape[0]), \
+                    torch.arange(12), \
+                    torch.arange(start=input.shape[0], end=input.shape[0] + 5)
+
+        x = torch.randn(5, 3, 2)
+        y = torch.randn(8, 3, 2)
+        self.run_test(ArangeModel(), x, test_with_inputs=[y],
+                      dynamic_axes={'input_1': [0],
+                                    'output_1': [0],
+                                    'output_2': [0]})
+
     def _test_index_generic(self, fn):
         class MyModel(torch.nn.Module):
             def __init__(self):
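
In test_arange above, dynamic_axes marks which dimensions the exporter should leave symbolic instead of freezing to the traced input's sizes; that is what lets test_with_inputs re-run the exported model with y of shape (8, 3, 2). A minimal sketch of the same mechanism with plain torch.onnx.export (illustrative, not from the PR):

import io

import torch

class ArangeModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])

f = io.BytesIO()
# Dimension 0 of the input and of the output is declared dynamic.
torch.onnx.export(ArangeModel(), torch.randn(5, 3, 2), f,
                  input_names=['input_1'], output_names=['output_1'],
                  dynamic_axes={'input_1': [0], 'output_1': [0]})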
@@ -1553,6 +1553,40 @@ graph(%Ra, %Rb):
     def test_trace_size_with_grad(self):
         self.do_trace_size(True)
 
+    def do_trace_arange(self, requires_grad):
+        def arange(x):
+            return torch.arange(x.shape[0])
+
+        def arange_scalar(x):
+            return torch.arange(12)
+
+        def arange_start_end(x):
+            return torch.arange(start=x.shape[0], end=x.shape[0] + 5)
+
+        x = torch.randn(5, 3, 2, requires_grad=requires_grad)
+        y = torch.randn(8, 2, 4, requires_grad=requires_grad)
+
+        # Check that it behaves as expected
+        traced_arange = torch.jit.trace(arange, x)
+        self.assertEqual(traced_arange(y), arange(y))
+        self.assertEqual(traced_arange(x), arange(x))
+
+        traced_arange_scalar = torch.jit.trace(arange_scalar, x)
+        self.assertEqual(traced_arange_scalar(y), arange_scalar(y))
+        self.assertEqual(traced_arange_scalar(x), arange_scalar(x))
+
+        traced_arange_start_end = torch.jit.trace(arange_start_end, x)
+        self.assertEqual(traced_arange_start_end(y), arange_start_end(y))
+        self.assertEqual(traced_arange_start_end(x), arange_start_end(x))
+
+    def test_trace_arange(self):
+        self.do_trace_arange(False)
+
+    # test the different graph_executor path that happens when
+    # gradients are required and sizes are involved
+    def test_trace_arange_with_grad(self):
+        self.do_trace_arange(True)
+
     def test_trace_casts(self):
         casts = [
             lambda x: x.byte(),
@@ -12018,13 +12052,6 @@ a")
             FooMod(), (torch.rand(3, 4),), f,
             operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
 
-    def test_trace_checker_arange_as_constant(self):
-        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
-            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
-            def foo(x):
-                y = torch.arange(0, x.shape[0]).double()
-                return x + y.unsqueeze(1)
-
     @suppress_warnings
     def test_trace_checker_dot_data(self):
         with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
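
The removed checker test above asserted the old, broken behavior: tracing arange-from-shape and re-checking with a differently sized input used to raise 'Graphs differed across invocations!'. With the fix, the equivalent code passes (illustrative sketch, not part of the PR):

import torch

def foo(x):
    y = torch.arange(0, x.shape[0]).double()
    return x + y.unsqueeze(1)

# check_inputs re-traces with the extra inputs and compares the graphs;
# this no longer raises torch.jit.TracingCheckError.
traced_foo = torch.jit.trace(foo, torch.rand(3, 4),
                             check_inputs=[(torch.rand(4, 5),)])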
@@ -112,7 +112,7 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k
   static PythonArgParser parser({
     "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
     "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
-  });
+  }, /*traceable=*/true);
 
   ParsedArgs<9> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
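
This one-line change in the arange binding appears to be the core of the fix: constructing the PythonArgParser with /*traceable=*/true lets the tracer record arange's Scalar arguments when they come from traced tensors (via prim::ImplicitTensorToNum, as the new trace in the summary shows) instead of evaluating the call eagerly and baking the result in as a constant. A quick graph-level check (illustrative, not part of the PR):

import torch

def f(x):
    return torch.arange(x.shape[0])

graph = torch.jit.trace(f, torch.randn(5, 3, 2)).graph
# With the fix, the end of the range is derived from the input's size.
assert any(node.kind() == 'aten::size' for node in graph.nodes())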