Don't optimize slicing dispatch when we are tracing (#11156)
Summary: Previously, when we had a slicing expression like `x[0:5, 0]` and the sliced tensor was of size `5` in dimension 0, we would skip dispatching the actual slice call as an optimization. This caused incorrect behavior under tracing: because the slice op was never recorded, running the trace on an input with a different shape produced incorrect results.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11156
Differential Revision: D9622252
Pulled By: jamesr66a
fbshipit-source-id: 822f2e8f01504e131f53bd9ef51c171c7913a7cc
This commit is contained in:
parent b3d559cdd1
commit 43e73f85ad
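The new test below exercises exactly this case. As a standalone illustration (not part of the commit; the helper name `foo` and the shapes simply mirror the added test), the failure mode looks roughly like this: trace a function whose slice spans the full dimension of the example input, then replay the trace on an input of a different shape.

import torch

def foo(x):
    # The slice covers all of dim 0 for the example input (size 5), which is
    # exactly the case that used to be optimized away instead of dispatched.
    return x[0:5, 0] + 1.0

# Trace with an input whose dim-0 size matches the slice bound exactly.
traced = torch.jit.trace(foo, (torch.rand(5, 4),))

# Replay on a different shape. Before this fix the trace contained no slice
# op, so the traced result could disagree with eager execution; with the fix
# the two results match.
test_x = torch.rand(6, 3)
print(torch.equal(foo(test_x), traced(test_x)))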
@@ -1528,6 +1528,14 @@ class TestJit(JitTestCase):
         self.assertExpected(torch.onnx.export_to_pretty_string(
             Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
 
+    def test_trace_slice_full_dim(self):
+        def foo(x):
+            return x[0:5, 0] + 1.0
+
+        traced = torch.jit.trace(foo, (torch.rand(5, 4),))
+        test_x = torch.rand(6, 3)
+        self.assertEqual(foo(test_x), traced(test_x))
+
 
 class TestBatched(TestCase):
     # generate random examples and create an batchtensor with them
@@ -11,6 +11,7 @@
 #include "torch/csrc/utils/python_numbers.h"
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_conversion_dispatch.h"
+#include "torch/csrc/jit/tracer.h"
 
 #include <ATen/DeviceGuard.h>
 #include <ATen/ExpandUtils.h>
@@ -79,7 +80,10 @@ static Variable applySlice(const Variable& self, int64_t dim, PyObject* slice, b
     // TODO: implement negative step
     throw ValueError("negative step not yet supported");
   }
-  if (!ensure_view && start == 0 && stop == length && step == 1) {
+  // Skip this optimization if we are tracing, as the trace may be polymorphic
+  // over the shape of the `self` tensor, and we still want to record
+  // the slice.
+  if (!ensure_view && start == 0 && stop == length && step == 1 && !jit::tracer::isTracing()) {
     return self;
   }
   return self.slice(dim, start, stop, step);
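A quick way to see the effect of this change (again a sketch, not part of the PR; the `aten::slice` kind string is an assumption about how the op appears in the traced IR) is to inspect the recorded graph and confirm that the full-dimension slice is no longer elided:

import torch

def foo(x):
    # Same toy function as in the sketch above.
    return x[0:5, 0] + 1.0

traced = torch.jit.trace(foo, (torch.rand(5, 4),))

# With isTracing() gating the shortcut, the slice is dispatched and therefore
# recorded, so a slice node should show up in the traced graph.
print(any(node.kind() == "aten::slice" for node in traced.graph.nodes()))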