Revert "[inductor] fix crash issue when input is a view tensor (#90150)"

This reverts commit b11ec270ba.

Reverted https://github.com/pytorch/pytorch/pull/90150 on behalf of https://github.com/clee2000 due to the failing test_inplace_unsqueeze3 (__main__.CPUReproTests), see https://github.com/pytorch/pytorch/actions/runs/4074618739/jobs/7020199369 (commit b11ec270ba); marking as a land race because all jobs were green on the PR.
Author: PyTorch MergeBot
Date:   2023-02-02 17:06:34 +00:00
Parent: 769eca6f97
Commit: 5d259425fc

9 changed files with 4 additions and 140 deletions
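The reverted change dealt with in-place metadata mutation of graph inputs. As a minimal eager-mode sketch of that side effect (not taken from the PR), aten.unsqueeze_ rewrites the shape and stride of the caller's tensor rather than returning a fresh view:

import torch

a = torch.zeros(2, 3)
torch.ops.aten.unsqueeze_.default(a, 0)  # the op the removed tests exercise; alias of a.unsqueeze_(0)
print(a.shape)    # torch.Size([1, 2, 3]) -- the caller's tensor was reshaped in place
print(a.stride()) # (6, 3, 1)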

@@ -6044,78 +6044,6 @@ if HAS_CPU:
                 if simdlen != 1:
                     assert metrics.generated_cpp_vec_kernel_count == 1
-
-        def test_inplace_unsqueeze(self):
-            @torch._dynamo.optimize("inductor")
-            def fn(a):
-                unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
-                return unsqueeze_
-
-            for dynamic_shapes in [True, False]:
-                args = [
-                    (
-                        (1, 1, 1, 12, 11, 3),
-                        (396, 396, 396, 33, 3, 1),
-                        torch.int64,
-                        "cpu",
-                    )
-                ]
-                args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
-                config.dynamic_shapes = dynamic_shapes
-                torch._dynamo.config.dynamic_shapes = dynamic_shapes
-                with torch.no_grad():
-                    out = fn(*args)
-                assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
-                assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
-                assert out.equal(args[0])
-
-        def test_inplace_unsqueeze2(self):
-            @torch._dynamo.optimize("inductor")
-            def fn(a):
-                unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
-                res = unsqueeze_ + 1
-                return res
-
-            for dynamic_shapes in [True, False]:
-                args = [
-                    (
-                        (1, 1, 1, 12, 11, 3),
-                        (396, 396, 396, 33, 3, 1),
-                        torch.int64,
-                        "cpu",
-                    )
-                ]
-                args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
-                config.dynamic_shapes = dynamic_shapes
-                torch._dynamo.config.dynamic_shapes = dynamic_shapes
-                with torch.no_grad():
-                    out = fn(*args)
-                assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
-                assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
-                assert out.equal(args[0] + 1)
-
-        def test_inplace_unsqueeze3(self):
-            @torch._dynamo.optimize("inductor")
-            def fn(a):
-                torch.ops.aten.unsqueeze_.default(a, 0)
-                return 0
-
-            for dynamic_shapes in [True, False]:
-                args = [
-                    (
-                        (1, 1, 1, 12, 11, 3),
-                        (396, 396, 396, 33, 3, 1),
-                        torch.int64,
-                        "cpu",
-                    )
-                ]
-                args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
-                config.dynamic_shapes = dynamic_shapes
-                torch._dynamo.config.dynamic_shapes = dynamic_shapes
-                with torch.no_grad():
-                    fn(*args)
-                assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
-                assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
-

 if HAS_CUDA and not TEST_WITH_ASAN:
     import triton
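The shape and stride values asserted in these removed tests follow from eager unsqueeze_ semantics alone; a standalone check (not part of the test file) is:

import torch

a = torch.zeros((1, 1, 1, 12, 11, 3), dtype=torch.int64)
assert a.stride() == (396, 396, 396, 33, 3, 1)

a.unsqueeze_(0)  # the metadata change the compiled fn is expected to reproduce on its input
assert a.shape == (1, 1, 1, 1, 12, 11, 3)
assert a.stride() == (396, 396, 396, 396, 33, 3, 1)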

@@ -142,44 +142,6 @@ class GraphArg:
             assert isinstance(
                 self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
             )
-            # For inplace ops changing the input's shape (unsqueeze_)
-            if not config.dynamic_shapes and (
-                self.fake_tensor.shape != self.example.shape
-                or self.fake_tensor.stride() != self.example.stride()
-            ):
-                converter = torch._subclasses.fake_tensor.FakeTensorConverter()
-                self.fake_tensor = converter.from_real_tensor(
-                    self.fake_tensor.fake_mode, self.example
-                )
-            elif config.dynamic_shapes:
-                (
-                    size,
-                    stride,
-                    _,
-                ) = self.fake_tensor.fake_mode.shape_env.create_symbolic_sizes_strides_storage_offset(
-                    self.example, self.source
-                )
-                if (
-                    torch.Size(size) != self.fake_tensor.shape
-                    or tuple(stride) != self.fake_tensor.stride()
-                ):
-                    self.fake_tensor.fake_mode.converter = (
-                        torch._subclasses.fake_tensor.FakeTensorConverter()
-                    )
-                    self.fake_tensor.fake_mode.shape_env = (
-                        torch.fx.experimental.symbolic_shapes.ShapeEnv()
-                    )
-                    ignore_subclass = (
-                        True
-                        if type(self.example) in config.traceable_tensor_subclasses
-                        else False
-                    )
-                    self.fake_tensor = self.fake_tensor.fake_mode.from_tensor(
-                        self.example.clone(),
-                        static_shapes=False,
-                        ignore_subclass=ignore_subclass,
-                        source=self.source,
-                    )
             return [self.fake_tensor]

     def __len__(self):
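The removed branch re-fakeified the example tensor once an in-place op had changed its metadata. A hedged, self-contained sketch of that detect-and-rebuild pattern, using the converter call referenced above plus a default-constructed FakeTensorMode (an assumption for this illustration):

import torch
from torch._subclasses.fake_tensor import FakeTensorConverter, FakeTensorMode

fake_mode = FakeTensorMode()
example = torch.zeros(2, 3)
fake = FakeTensorConverter().from_real_tensor(fake_mode, example)

example.unsqueeze_(0)  # the real input is mutated after the fake copy was taken

if fake.shape != example.shape or fake.stride() != example.stride():
    # metadata diverged: rebuild the fake tensor from the mutated real tensor
    fake = FakeTensorConverter().from_real_tensor(fake_mode, example)

assert fake.shape == example.shape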

@@ -1049,10 +1049,7 @@ class AOTConfig:
 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
-    # flat_args is used by make_fx and aot_config.fw_compiler
-    # clone flat_args to avoid flat_args shape changed by inplace ops (unsqueeze_)
-    tmp_flat_args = [torch._prims_common.clone_preserve_strides(x) for x in flat_args]
-    fw_module = make_fx(flat_fn, aot_config.decompositions)(*tmp_flat_args)
+    fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
     if config.debug_graphs:
         log.debug(f"====== Forward (only) graph {aot_config.aot_id} ======")
         log.debug(fw_module.print_readable(print_output=False))
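The removed clone existed so that tracing with make_fx could not mutate the caller's flat_args. A rough approximation of torch._prims_common.clone_preserve_strides (an assumed simplification that ignores storage offsets and overlapping layouts):

import torch

def clone_preserve_strides_sketch(x: torch.Tensor) -> torch.Tensor:
    # Copy the data into a fresh buffer that keeps the original size/stride,
    # so in-place metadata changes on the clone cannot leak back to `x`.
    out = torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
    out.copy_(x)
    return out

a = torch.zeros((1, 1, 1, 12, 11, 3), dtype=torch.int64)
b = clone_preserve_strides_sketch(a)
b.unsqueeze_(0)
assert a.shape == (1, 1, 1, 12, 11, 3)  # the original metadata is untouched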

@@ -509,10 +509,6 @@ class WrapperCodeGen(CodeGen):
             # these lines will be pointless
             self.lines.pop()
-
-        for name, value in V.graph.graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                self.wrapper_call.writeline(value.data.codegen_reference_mutation())
         # codegen allocations in two passes
         planning_state = MemoryPlanningState()
         for i in range(len(self.lines)):
@@ -579,8 +575,6 @@ class WrapperCodeGen(CodeGen):
             )
         for name, value in V.graph.graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                value = value.data.data
             shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
             stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
             add_fake_input(

@@ -366,8 +366,6 @@ class GraphLowering(torch.fx.Interpreter):
             value.realize()
             assert isinstance(value, TensorBox)
             value = value.data
-            if isinstance(value, ir.ReinterpretView):
-                continue
             assert isinstance(value, ir.StorageBox)
             value_storage_box = value
             value = value.data

@@ -1470,14 +1470,6 @@ class ReinterpretView(BaseView):
             return f"{as_strided}({self.get_name()}, {size}, {stride}, {offset})"
         return f"{as_strided}({self.get_name()}, {size}, {stride})"

-    def codegen_reference_mutation(self):
-        size = V.graph.sizevars.codegen_shape_tuple(self.layout.size)
-        stride = V.graph.sizevars.codegen_shape_tuple(self.layout.stride)
-        offset = V.graph.sizevars.codegen_sizevar(self.layout.offset)
-        if offset != "0":
-            return f"{self.get_name()}.as_strided_({size}, {stride}, {offset})"
-        return f"{self.get_name()}.as_strided_({size}, {stride})"
-

 class SliceView(View):
     @classmethod
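The string this removed method emitted is an in-place as_strided_ call in the generated wrapper; in plain eager terms it re-applies the mutated size/stride to the input buffer. A standalone illustration (not generated code):

import torch

buf = torch.zeros((1, 1, 1, 12, 11, 3), dtype=torch.int64)
# equivalent of the emitted `<name>.as_strided_(size, stride)` line:
buf.as_strided_((1, 1, 1, 1, 12, 11, 3), (396, 396, 396, 396, 33, 3, 1))
assert buf.shape == (1, 1, 1, 1, 12, 11, 3)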

@@ -1016,9 +1016,8 @@ class Scheduler:
                 V.graph.wrapper_code.codegen_free(node.node)
             elif name in V.graph.graph_inputs:
                 storage = V.graph.graph_inputs[name].data
-                if not isinstance(storage, ir.ReinterpretView):
-                    assert storage.is_input_buffer()
-                    V.graph.wrapper_code.codegen_free(storage.data)
+                assert storage.is_input_buffer()
+                V.graph.wrapper_code.codegen_free(storage.data)

         self.buffer_names_to_free.clear()

@@ -448,8 +448,6 @@ class SizeVarAllocator(object):
         needed = set(self.var_to_val.keys()) - set(self.replacements.keys())

         for name, value in graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                value = value.data.data
             shapes = value.get_size()
             for dim, shape in enumerate(shapes):
                 shape = self.simplify(shape)
@@ -460,8 +458,6 @@ class SizeVarAllocator(object):
                 )
         for name, value in graph_inputs.items():
-            if isinstance(value.data, ir.ReinterpretView):
-                value = value.data.data
             shapes = value.get_stride()
             for dim, shape in enumerate(shapes):
                 shape = self.simplify(shape)

@@ -182,6 +182,4 @@ class ShapeProp(torch.fx.Interpreter):
         Returns:
             Any: The value returned from executing the Module
         """
-        # clone inputs to avoid side effects caused by inplace ops during run_node
-        new_args = [torch._prims_common.clone_preserve_strides(x) for x in args]
-        return super().run(*new_args)
+        return super().run(*args)
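Since ShapeProp is a torch.fx.Interpreter, the side effect the removed cloning guarded against can be reproduced with a plain Interpreter run (a hedged sketch, independent of inductor):

import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        x.unsqueeze_(0)  # in-place op recorded in the traced graph
        return x.sum()

gm = torch.fx.symbolic_trace(M())
a = torch.zeros(2, 3)
torch.fx.Interpreter(gm).run(a)  # executes unsqueeze_ on the real tensor
print(a.shape)                   # torch.Size([1, 2, 3]): the caller's input was mutated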