From 5d259425fcff9c6eb4032f63aa33ab58d24aff85 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Thu, 2 Feb 2023 17:06:34 +0000
Subject: [PATCH] Revert "[inductor] fix crash issue when input is a view
 tensor (#90150)"

This reverts commit b11ec270bad96bf6078564ec4b2dc5dc69ea5bfa.

Reverted https://github.com/pytorch/pytorch/pull/90150 on behalf of https://github.com/clee2000 due to failing test_inplace_unsqueeze3 (__main__.CPUReproTests) https://github.com/pytorch/pytorch/actions/runs/4074618739/jobs/7020199369 https://hud.pytorch.org/pytorch/pytorch/commit/b11ec270bad96bf6078564ec4b2dc5dc69ea5bfa, marking as landrace cuz all jobs are green on pr
---
 test/inductor/test_torchinductor.py | 72 -----------------------------
 torch/_dynamo/variables/builder.py  | 38 ---------------
 torch/_functorch/aot_autograd.py    |  5 +-
 torch/_inductor/codegen/wrapper.py  |  6 ---
 torch/_inductor/graph.py            |  2 -
 torch/_inductor/ir.py               |  8 ----
 torch/_inductor/scheduler.py        |  5 +-
 torch/_inductor/sizevars.py         |  4 --
 torch/fx/passes/shape_prop.py       |  4 +-
 9 files changed, 4 insertions(+), 140 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index c596f883a38..4fa7dc360f0 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -6044,78 +6044,6 @@ if HAS_CPU:
             if simdlen != 1:
                 assert metrics.generated_cpp_vec_kernel_count == 1
 
-        def test_inplace_unsqueeze(self):
-            @torch._dynamo.optimize("inductor")
-            def fn(a):
-                unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
-                return unsqueeze_
-
-            for dynamic_shapes in [True, False]:
-                args = [
-                    (
-                        (1, 1, 1, 12, 11, 3),
-                        (396, 396, 396, 33, 3, 1),
-                        torch.int64,
-                        "cpu",
-                    )
-                ]
-                args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
-                config.dynamic_shapes = dynamic_shapes
-                torch._dynamo.config.dynamic_shapes = dynamic_shapes
-                with torch.no_grad():
-                    out = fn(*args)
-                assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
-                assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
-                assert out.equal(args[0])
-
-        def test_inplace_unsqueeze2(self):
-            @torch._dynamo.optimize("inductor")
-            def fn(a):
-                unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
-                res = unsqueeze_ + 1
-                return res
-
-            for dynamic_shapes in [True, False]:
-                args = [
-                    (
-                        (1, 1, 1, 12, 11, 3),
-                        (396, 396, 396, 33, 3, 1),
-                        torch.int64,
-                        "cpu",
-                    )
-                ]
-                args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
-                config.dynamic_shapes = dynamic_shapes
-                torch._dynamo.config.dynamic_shapes = dynamic_shapes
-                with torch.no_grad():
-                    out = fn(*args)
-                assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
-                assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
-                assert out.equal(args[0] + 1)
-
-        def test_inplace_unsqueeze3(self):
-            @torch._dynamo.optimize("inductor")
-            def fn(a):
-                torch.ops.aten.unsqueeze_.default(a, 0)
-                return 0
-
-            for dynamic_shapes in [True, False]:
-                args = [
-                    (
-                        (1, 1, 1, 12, 11, 3),
-                        (396, 396, 396, 33, 3, 1),
-                        torch.int64,
-                        "cpu",
-                    )
-                ]
-                args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
-                config.dynamic_shapes = dynamic_shapes
-                torch._dynamo.config.dynamic_shapes = dynamic_shapes
-                with torch.no_grad():
-                    fn(*args)
-                assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
-                assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
-
 
 if HAS_CUDA and not TEST_WITH_ASAN:
     import triton
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 149b0d7cba3..16c57e2d7c0 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -142,44 +142,6 @@ class GraphArg:
             assert isinstance(
                 self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
             )
-            # For inplace ops changing the input's shape (unsqueeze_)
-            if not config.dynamic_shapes and (
-                self.fake_tensor.shape != self.example.shape
-                or self.fake_tensor.stride() != self.example.stride()
-            ):
-                converter = torch._subclasses.fake_tensor.FakeTensorConverter()
-                self.fake_tensor = converter.from_real_tensor(
-                    self.fake_tensor.fake_mode, self.example
-                )
-            elif config.dynamic_shapes:
-                (
-                    size,
-                    stride,
-                    _,
-                ) = self.fake_tensor.fake_mode.shape_env.create_symbolic_sizes_strides_storage_offset(
-                    self.example, self.source
-                )
-                if (
-                    torch.Size(size) != self.fake_tensor.shape
-                    or tuple(stride) != self.fake_tensor.stride()
-                ):
-                    self.fake_tensor.fake_mode.converter = (
-                        torch._subclasses.fake_tensor.FakeTensorConverter()
-                    )
-                    self.fake_tensor.fake_mode.shape_env = (
-                        torch.fx.experimental.symbolic_shapes.ShapeEnv()
-                    )
-                    ignore_subclass = (
-                        True
-                        if type(self.example) in config.traceable_tensor_subclasses
-                        else False
-                    )
-                    self.fake_tensor = self.fake_tensor.fake_mode.from_tensor(
-                        self.example.clone(),
-                        static_shapes=False,
-                        ignore_subclass=ignore_subclass,
-                        source=self.source,
-                    )
             return [self.fake_tensor]
 
     def __len__(self):
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index eca646e2ac7..c8b16dc4450 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1049,10 +1049,7 @@ class AOTConfig:
 
 
 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
-    # flat_args is used by make_fx and aot_config.fw_compiler
-    # clone flat_args to avoid flat_args shape changed by inplace ops (unsqueeze_)
-    tmp_flat_args = [torch._prims_common.clone_preserve_strides(x) for x in flat_args]
-    fw_module = make_fx(flat_fn, aot_config.decompositions)(*tmp_flat_args)
+    fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
     if config.debug_graphs:
         log.debug(f"====== Forward (only) graph {aot_config.aot_id} ======")
         log.debug(fw_module.print_readable(print_output=False))
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index c43681144b3..965295a70af 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -509,10 +509,6 @@ class WrapperCodeGen(CodeGen):
                 # these lines will be pointless
                 self.lines.pop()
 
-        for name, value in V.graph.graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                self.wrapper_call.writeline(value.data.codegen_reference_mutation())
-
         # codegen allocations in two passes
         planning_state = MemoryPlanningState()
         for i in range(len(self.lines)):
@@ -579,8 +575,6 @@ class WrapperCodeGen(CodeGen):
         )
 
         for name, value in V.graph.graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                value = value.data.data
             shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
             stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
             add_fake_input(
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 76e17dd5676..8f3c75bb6fd 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -366,8 +366,6 @@ class GraphLowering(torch.fx.Interpreter):
             value.realize()
             assert isinstance(value, TensorBox)
             value = value.data
-            if isinstance(value, ir.ReinterpretView):
-                continue
             assert isinstance(value, ir.StorageBox)
             value_storage_box = value
             value = value.data
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index eb05f75e925..46e1c031916 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -1470,14 +1470,6 @@ class ReinterpretView(BaseView):
             return f"{as_strided}({self.get_name()}, {size}, {stride}, {offset})"
         return f"{as_strided}({self.get_name()}, {size}, {stride})"
 
-    def codegen_reference_mutation(self):
-        size = V.graph.sizevars.codegen_shape_tuple(self.layout.size)
-        stride = V.graph.sizevars.codegen_shape_tuple(self.layout.stride)
-        offset = V.graph.sizevars.codegen_sizevar(self.layout.offset)
-        if offset != "0":
-            return f"{self.get_name()}.as_strided_({size}, {stride}, {offset})"
-        return f"{self.get_name()}.as_strided_({size}, {stride})"
-
 
 class SliceView(View):
     @classmethod
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 1e170887dc3..dbd060f922e 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1016,9 +1016,8 @@ class Scheduler:
                 V.graph.wrapper_code.codegen_free(node.node)
             elif name in V.graph.graph_inputs:
                 storage = V.graph.graph_inputs[name].data
-                if not isinstance(storage, ir.ReinterpretView):
-                    assert storage.is_input_buffer()
-                    V.graph.wrapper_code.codegen_free(storage.data)
+                assert storage.is_input_buffer()
+                V.graph.wrapper_code.codegen_free(storage.data)
 
         self.buffer_names_to_free.clear()
 
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index 18d6ed33907..146f7e48cad 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -448,8 +448,6 @@ class SizeVarAllocator(object):
         needed = set(self.var_to_val.keys()) - set(self.replacements.keys())
 
         for name, value in graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                value = value.data.data
             shapes = value.get_size()
             for dim, shape in enumerate(shapes):
                 shape = self.simplify(shape)
@@ -460,8 +458,6 @@ class SizeVarAllocator(object):
                     )
 
         for name, value in graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                value = value.data.data
             shapes = value.get_stride()
             for dim, shape in enumerate(shapes):
                 shape = self.simplify(shape)
diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py
index a7e3aed9e9f..2cc11dbd4cd 100644
--- a/torch/fx/passes/shape_prop.py
+++ b/torch/fx/passes/shape_prop.py
@@ -182,6 +182,4 @@ class ShapeProp(torch.fx.Interpreter):
         Returns:
             Any: The value returned from executing the Module
         """
-        # clone inputs to avoid side effects caused by inplace ops during run_node
-        new_args = [torch._prims_common.clone_preserve_strides(x) for x in args]
-        return super().run(*new_args)
+        return super().run(*args)