[dynamo] Restrict support for out= variants of torch operators (#140202)

There has been a series of attempts to support resizing in `out=`
variants of torch operators like `torch.sigmoid(x, out=y)`, i.e., cases
where `y` has a different shape before and after the expression. Prior
to this patch, we had some checks that graph break if the shape changes.
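
For context, a minimal eager-mode sketch of the resizing behavior in
question (assuming a zero-element `out` tensor, which PyTorch resizes
silently):

```python
import torch

x = torch.zeros(3, 5)
y = torch.empty(0)       # zero-element placeholder, shape torch.Size([0])
torch.sigmoid(x, out=y)  # the out= variant resizes y in place
print(y.shape)           # torch.Size([3, 5]) -- a different shape than before
```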

This patch
1. extends the existing check and graph break to any shape change, not
   just for `TensorVariable`s with a source field.
2. removes an old code path which was introduced to address the shape
   change, but became obsolete in that regard once we added the extra
   checks to graph break upon shape change. Moreover, this old code path
   is unsound: it tries to replace references to the old
   `TensorVariable` with the new one returned by `wrap_fx_proxy`, but it
   only does the replacement in `symbolic_locals`, which breaks when
   cells are involved (see the sketch below). In general the old
   `TensorVariable` could be _anywhere_; think of the `replace_all` we
   had for immutable VTs.
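
To see why replacing references only in `symbolic_locals` is unsound,
here is a plain-Python sketch of the aliasing problem (the names are
hypothetical stand-ins, not Dynamo's actual types):

```python
# `symbolic_locals` maps local names to variable trackers, but the same
# tracker object can also be referenced from elsewhere, e.g. a cell
# captured by a closure.
old_vt = object()                  # stands in for the old TensorVariable
new_vt = object()                  # stands in for the result of wrap_fx_proxy
symbolic_locals = {"out": old_vt}
cell_contents = old_vt             # a second reference, via a captured cell

symbolic_locals["out"] = new_vt    # the old code path only rebinds here...
assert cell_contents is old_vt     # ...so the cell still holds the stale VT
```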

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140202
Approved by: https://github.com/jansel
ghstack dependencies: #140035, #140036, #140149, #140150, #140151, #140201
Author: Ryan Guo, 2024-11-12 19:05:05 -05:00 (committed by PyTorch MergeBot)
Commit: 39d1c91c33 (parent 65615915ed)
3 changed files with 82 additions and 26 deletions


```diff
@@ -2225,6 +2225,75 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         res = opt_m(x)
         self.assertTrue(same(ref, res))
 
+    def test_out_root_cell_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            out = torch.empty(0)
+
+            def run():
+                x = torch.zeros(3, 5)
+                torch.sigmoid(x, out=out)
+                return out.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual((3, 5), res)
+
+    def test_out_nested_cell_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            def run():
+                x = torch.zeros(3, 5)
+                out = torch.empty(0)
+
+                def capture():
+                    return out  # Force `out` to be a nested cell
+
+                torch.sigmoid(x, out=out)
+                return out.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual((3, 5), res)
+
+    def test_out_root_cell_tuple_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            out1 = torch.empty(0)
+            out2 = torch.empty(0, dtype=torch.long)
+
+            def run():
+                x = torch.zeros(3, 5)
+                torch.sort(x, out=(out1, out2))
+                return out1.size(), out2.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual(((3, 5), (3, 5)), res)
+
+    def test_out_nested_cell_tuple_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            def run():
+                x = torch.zeros(3, 5)
+                out1 = torch.empty(0)
+                out2 = torch.empty(0, dtype=torch.long)
+
+                def capture():
+                    # Force `out1` and `out2` to be nested cells
+                    return out1, out2
+
+                torch.sort(x, out=(out1, out2))
+                return out1.size(), out2.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual(((3, 5), (3, 5)), res)
+
     def test_slice_into_list_mutable(self):
         class Mod(torch.nn.Module):
             def forward(self, listy):
```


```diff
@@ -2642,12 +2642,6 @@ class InstructionTranslatorBase(
     def fake_mode(self):
         return self.output.tracing_context.fake_mode
 
-    def find_symbolic_locals_name(self, tensor_variable):
-        for key, value in self.symbolic_locals.items():
-            if value is tensor_variable:
-                return key
-        return None
-
     @contextlib.contextmanager
     def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
         """
```


```diff
@@ -973,24 +973,24 @@ Either create the tensor outside the compiled region, or do not set the tensor t
             isinstance(kwargs["out"], variables.ConstantVariable)
             and kwargs["out"].as_python_constant() is None
         ):
-            # out variants of torch operators like torch.sort and
-            # torch.sigmoid mutate the tensors in the out field. Track such
-            # tensors and rewrite the symbolic locals.
+            # out variants of torch operators like torch.sort and torch.sigmoid
+            # mutate the tensors in the out field.
+            #
+            # However, it's non-trivial to update all references of the old
+            # `TensorVariable` to the new one returned (`result_var`), so we
+            # take the conservative approach to graph break on size changes, and
+            # assume other cases can fall through soundly.
+            #
+            # Note that although these tensor variables would hold different
+            # proxies, the in-place mutation semantics is preserved in the FX
+            # graph, so we won't have correctness issues.
             if isinstance(tensor_variable, TupleVariable):
                 assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
-                output_tensor_names = [
-                    tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
-                ]
-                for idx, name in enumerate(output_tensor_names):
-                    if name in tx.symbolic_locals:
-                        tx.symbolic_locals[name] = tensor_variable.items[idx]
                 for out_tensor, result_tensor in zip(
                     kwargs["out"].items, tensor_variable.items
                 ):
                     if (
-                        out_tensor.source
-                        and out_tensor in tx.output.graphargs
-                        and isinstance(out_tensor, variables.TensorVariable)
+                        isinstance(out_tensor, variables.TensorVariable)
                         and isinstance(result_tensor, variables.TensorVariable)
                         and out_tensor._size
                         != result_tensor._size  # we actually want to compare None values here
@@ -1003,11 +1003,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                 assert "example_value" in kwargs["out"].proxy.node.meta
                 fake_tensor = tensor_variable.proxy.node.meta["example_value"]
                 fake_out = kwargs["out"].proxy.node.meta["example_value"]
-                if (
-                    kwargs["out"].source
-                    and kwargs["out"] in tx.output.graphargs
-                    and fake_out_shape != fake_tensor.shape
-                ):
+                if fake_out_shape != fake_tensor.shape:
                     # It's hard to get out variants with resizing on graph inputs work
                     # properly across dynamo/aot/inductor, just fall back.
                     unimplemented("out variants with resizing on graph inputs")
@@ -1017,9 +1013,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                     unimplemented(
                         "out= op was called where output tensor was non-contiguous"
                     )
-                name = tx.find_symbolic_locals_name(kwargs["out"])
-                if name in tx.symbolic_locals:
-                    tx.symbolic_locals[name] = tensor_variable
             elif (
                 isinstance(tensor_variable, ConstantVariable)
                 and tensor_variable.value is None
```