Dont optimize slicing dispatch when we are tracing (#11156)

Summary:
Previously when we had a slicing expression like `x[0:5, 0]`, where the sliced tensor was of size `5` in dimension 0, we would skip dispatching the actual slice call as an optimization.

This caused incorrect behavior under tracing, as we would not record the slice op and thus if we encountered an input with a different shape while running the trace, we would get incorrect results.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11156

Differential Revision: D9622252

Pulled By: jamesr66a

fbshipit-source-id: 822f2e8f01504e131f53bd9ef51c171c7913a7cc
This commit is contained in:
James Reed 2018-09-01 17:06:03 -07:00 committed by Facebook Github Bot
parent b3d559cdd1
commit 43e73f85ad
2 changed files with 13 additions and 1 deletions

View File

@ -1528,6 +1528,14 @@ class TestJit(JitTestCase):
self.assertExpected(torch.onnx.export_to_pretty_string(
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
def test_trace_slice_full_dim(self):
def foo(x):
return x[0:5, 0] + 1.0
traced = torch.jit.trace(foo, (torch.rand(5, 4),))
test_x = torch.rand(6, 3)
self.assertEqual(foo(test_x), traced(test_x))
class TestBatched(TestCase):
# generate random examples and create an batchtensor with them

View File

@ -11,6 +11,7 @@
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/tensor_new.h"
#include "torch/csrc/utils/tensor_conversion_dispatch.h"
#include "torch/csrc/jit/tracer.h"
#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
@ -79,7 +80,10 @@ static Variable applySlice(const Variable& self, int64_t dim, PyObject* slice, b
// TODO: implement negative step
throw ValueError("negative step not yet supported");
}
if (!ensure_view && start == 0 && stop == length && step == 1) {
// Skip this optimization if we are tracing, as the trace may be polymorphic
// over the shape of the `self` tensor, and we still want to record
// the slice.
if (!ensure_view && start == 0 && stop == length && step == 1 && !jit::tracer::isTracing()) {
return self;
}
return self.slice(dim, start, stop, step);