mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "[inductor] fix crash issue when input is a view tensor (#90150)"
This reverts commitb11ec270ba. Reverted https://github.com/pytorch/pytorch/pull/90150 on behalf of https://github.com/clee2000 due to failing test_inplace_unsqueeze3 (__main__.CPUReproTests) https://github.com/pytorch/pytorch/actions/runs/4074618739/jobs/7020199369b11ec270ba, marking as landrace cuz all jobs are green on pr
This commit is contained in:
parent
769eca6f97
commit
5d259425fc
|
|
@ -6044,78 +6044,6 @@ if HAS_CPU:
|
|||
if simdlen != 1:
|
||||
assert metrics.generated_cpp_vec_kernel_count == 1
|
||||
|
||||
def test_inplace_unsqueeze(self):
|
||||
@torch._dynamo.optimize("inductor")
|
||||
def fn(a):
|
||||
unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
|
||||
return unsqueeze_
|
||||
|
||||
for dynamic_shapes in [True, False]:
|
||||
args = [
|
||||
(
|
||||
(1, 1, 1, 12, 11, 3),
|
||||
(396, 396, 396, 33, 3, 1),
|
||||
torch.int64,
|
||||
"cpu",
|
||||
)
|
||||
]
|
||||
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
|
||||
config.dynamic_shapes = dynamic_shapes
|
||||
torch._dynamo.config.dynamic_shapes = dynamic_shapes
|
||||
with torch.no_grad():
|
||||
out = fn(*args)
|
||||
assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
|
||||
assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
|
||||
assert out.equal(args[0])
|
||||
|
||||
def test_inplace_unsqueeze2(self):
|
||||
@torch._dynamo.optimize("inductor")
|
||||
def fn(a):
|
||||
unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
|
||||
res = unsqueeze_ + 1
|
||||
return res
|
||||
|
||||
for dynamic_shapes in [True, False]:
|
||||
args = [
|
||||
(
|
||||
(1, 1, 1, 12, 11, 3),
|
||||
(396, 396, 396, 33, 3, 1),
|
||||
torch.int64,
|
||||
"cpu",
|
||||
)
|
||||
]
|
||||
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
|
||||
config.dynamic_shapes = dynamic_shapes
|
||||
torch._dynamo.config.dynamic_shapes = dynamic_shapes
|
||||
with torch.no_grad():
|
||||
out = fn(*args)
|
||||
assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
|
||||
assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
|
||||
assert out.equal(args[0] + 1)
|
||||
|
||||
def test_inplace_unsqueeze3(self):
|
||||
@torch._dynamo.optimize("inductor")
|
||||
def fn(a):
|
||||
torch.ops.aten.unsqueeze_.default(a, 0)
|
||||
return 0
|
||||
|
||||
for dynamic_shapes in [True, False]:
|
||||
args = [
|
||||
(
|
||||
(1, 1, 1, 12, 11, 3),
|
||||
(396, 396, 396, 33, 3, 1),
|
||||
torch.int64,
|
||||
"cpu",
|
||||
)
|
||||
]
|
||||
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
|
||||
config.dynamic_shapes = dynamic_shapes
|
||||
torch._dynamo.config.dynamic_shapes = dynamic_shapes
|
||||
with torch.no_grad():
|
||||
fn(*args)
|
||||
assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
|
||||
assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
|
||||
|
||||
|
||||
if HAS_CUDA and not TEST_WITH_ASAN:
|
||||
import triton
|
||||
|
|
|
|||
|
|
@ -142,44 +142,6 @@ class GraphArg:
|
|||
assert isinstance(
|
||||
self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
|
||||
)
|
||||
# For inplace ops changing the input's shape (unsqueeze_)
|
||||
if not config.dynamic_shapes and (
|
||||
self.fake_tensor.shape != self.example.shape
|
||||
or self.fake_tensor.stride() != self.example.stride()
|
||||
):
|
||||
converter = torch._subclasses.fake_tensor.FakeTensorConverter()
|
||||
self.fake_tensor = converter.from_real_tensor(
|
||||
self.fake_tensor.fake_mode, self.example
|
||||
)
|
||||
elif config.dynamic_shapes:
|
||||
(
|
||||
size,
|
||||
stride,
|
||||
_,
|
||||
) = self.fake_tensor.fake_mode.shape_env.create_symbolic_sizes_strides_storage_offset(
|
||||
self.example, self.source
|
||||
)
|
||||
if (
|
||||
torch.Size(size) != self.fake_tensor.shape
|
||||
or tuple(stride) != self.fake_tensor.stride()
|
||||
):
|
||||
self.fake_tensor.fake_mode.converter = (
|
||||
torch._subclasses.fake_tensor.FakeTensorConverter()
|
||||
)
|
||||
self.fake_tensor.fake_mode.shape_env = (
|
||||
torch.fx.experimental.symbolic_shapes.ShapeEnv()
|
||||
)
|
||||
ignore_subclass = (
|
||||
True
|
||||
if type(self.example) in config.traceable_tensor_subclasses
|
||||
else False
|
||||
)
|
||||
self.fake_tensor = self.fake_tensor.fake_mode.from_tensor(
|
||||
self.example.clone(),
|
||||
static_shapes=False,
|
||||
ignore_subclass=ignore_subclass,
|
||||
source=self.source,
|
||||
)
|
||||
return [self.fake_tensor]
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
|||
|
|
@ -1049,10 +1049,7 @@ class AOTConfig:
|
|||
|
||||
|
||||
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
|
||||
# flat_args is used by make_fx and aot_config.fw_compiler
|
||||
# clone flat_args to avoid flat_args shape changed by inplace ops (unsqueeze_)
|
||||
tmp_flat_args = [torch._prims_common.clone_preserve_strides(x) for x in flat_args]
|
||||
fw_module = make_fx(flat_fn, aot_config.decompositions)(*tmp_flat_args)
|
||||
fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
|
||||
if config.debug_graphs:
|
||||
log.debug(f"====== Forward (only) graph {aot_config.aot_id} ======")
|
||||
log.debug(fw_module.print_readable(print_output=False))
|
||||
|
|
|
|||
|
|
@ -509,10 +509,6 @@ class WrapperCodeGen(CodeGen):
|
|||
# these lines will be pointless
|
||||
self.lines.pop()
|
||||
|
||||
for name, value in V.graph.graph_inputs.items():
|
||||
if isinstance(value.data, ir.ReinterpretView):
|
||||
self.wrapper_call.writeline(value.data.codegen_reference_mutation())
|
||||
|
||||
# codegen allocations in two passes
|
||||
planning_state = MemoryPlanningState()
|
||||
for i in range(len(self.lines)):
|
||||
|
|
@ -579,8 +575,6 @@ class WrapperCodeGen(CodeGen):
|
|||
)
|
||||
|
||||
for name, value in V.graph.graph_inputs.items():
|
||||
if isinstance(value.data, ir.ReinterpretView):
|
||||
value = value.data.data
|
||||
shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
|
||||
stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
|
||||
add_fake_input(
|
||||
|
|
|
|||
|
|
@ -366,8 +366,6 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
value.realize()
|
||||
assert isinstance(value, TensorBox)
|
||||
value = value.data
|
||||
if isinstance(value, ir.ReinterpretView):
|
||||
continue
|
||||
assert isinstance(value, ir.StorageBox)
|
||||
value_storage_box = value
|
||||
value = value.data
|
||||
|
|
|
|||
|
|
@ -1470,14 +1470,6 @@ class ReinterpretView(BaseView):
|
|||
return f"{as_strided}({self.get_name()}, {size}, {stride}, {offset})"
|
||||
return f"{as_strided}({self.get_name()}, {size}, {stride})"
|
||||
|
||||
def codegen_reference_mutation(self):
|
||||
size = V.graph.sizevars.codegen_shape_tuple(self.layout.size)
|
||||
stride = V.graph.sizevars.codegen_shape_tuple(self.layout.stride)
|
||||
offset = V.graph.sizevars.codegen_sizevar(self.layout.offset)
|
||||
if offset != "0":
|
||||
return f"{self.get_name()}.as_strided_({size}, {stride}, {offset})"
|
||||
return f"{self.get_name()}.as_strided_({size}, {stride})"
|
||||
|
||||
|
||||
class SliceView(View):
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1016,9 +1016,8 @@ class Scheduler:
|
|||
V.graph.wrapper_code.codegen_free(node.node)
|
||||
elif name in V.graph.graph_inputs:
|
||||
storage = V.graph.graph_inputs[name].data
|
||||
if not isinstance(storage, ir.ReinterpretView):
|
||||
assert storage.is_input_buffer()
|
||||
V.graph.wrapper_code.codegen_free(storage.data)
|
||||
assert storage.is_input_buffer()
|
||||
V.graph.wrapper_code.codegen_free(storage.data)
|
||||
|
||||
self.buffer_names_to_free.clear()
|
||||
|
||||
|
|
|
|||
|
|
@ -448,8 +448,6 @@ class SizeVarAllocator(object):
|
|||
needed = set(self.var_to_val.keys()) - set(self.replacements.keys())
|
||||
|
||||
for name, value in graph_inputs.items():
|
||||
if isinstance(value.data, ir.ReinterpretView):
|
||||
value = value.data.data
|
||||
shapes = value.get_size()
|
||||
for dim, shape in enumerate(shapes):
|
||||
shape = self.simplify(shape)
|
||||
|
|
@ -460,8 +458,6 @@ class SizeVarAllocator(object):
|
|||
)
|
||||
|
||||
for name, value in graph_inputs.items():
|
||||
if isinstance(value.data, ir.ReinterpretView):
|
||||
value = value.data.data
|
||||
shapes = value.get_stride()
|
||||
for dim, shape in enumerate(shapes):
|
||||
shape = self.simplify(shape)
|
||||
|
|
|
|||
|
|
@ -182,6 +182,4 @@ class ShapeProp(torch.fx.Interpreter):
|
|||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
# clone inputs to avoid side effects caused by inplace ops during run_node
|
||||
new_args = [torch._prims_common.clone_preserve_strides(x) for x in args]
|
||||
return super().run(*new_args)
|
||||
return super().run(*args)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user