diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 83c206a576a..1b7082a8933 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -2225,6 +2225,75 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         res = opt_m(x)
         self.assertTrue(same(ref, res))
 
+    def test_out_root_cell_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            out = torch.empty(0)
+
+            def run():
+                x = torch.zeros(3, 5)
+                torch.sigmoid(x, out=out)
+                return out.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual((3, 5), res)
+
+    def test_out_nested_cell_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            def run():
+                x = torch.zeros(3, 5)
+                out = torch.empty(0)
+
+                def capture():
+                    return out  # Force `out` to be a nested cell
+
+                torch.sigmoid(x, out=out)
+                return out.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual((3, 5), res)
+
+    def test_out_root_cell_tuple_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            out1 = torch.empty(0)
+            out2 = torch.empty(0, dtype=torch.long)
+
+            def run():
+                x = torch.zeros(3, 5)
+                torch.sort(x, out=(out1, out2))
+                return out1.size(), out2.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual(((3, 5), (3, 5)), res)
+
+    def test_out_nested_cell_tuple_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            def run():
+                x = torch.zeros(3, 5)
+                out1 = torch.empty(0)
+                out2 = torch.empty(0, dtype=torch.long)
+
+                def capture():
+                    # Force `out1` and `out2` to be nested cells
+                    return out1, out2
+
+                torch.sort(x, out=(out1, out2))
+                return out1.size(), out2.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual(((3, 5), (3, 5)), res)
+
     def test_slice_into_list_mutable(self):
         class Mod(torch.nn.Module):
             def forward(self, listy):
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 8922d7cbbab..e0245321ae1 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -2642,12 +2642,6 @@ class InstructionTranslatorBase(
     def fake_mode(self):
         return self.output.tracing_context.fake_mode
 
-    def find_symbolic_locals_name(self, tensor_variable):
-        for key, value in self.symbolic_locals.items():
-            if value is tensor_variable:
-                return key
-        return None
-
     @contextlib.contextmanager
     def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
         """
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 6f500fff176..60e5d865ee3 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -973,24 +973,24 @@ Either create the tensor outside the compiled region, or do not set the tensor t
             isinstance(kwargs["out"], variables.ConstantVariable)
             and kwargs["out"].as_python_constant() is None
         ):
-            # out variants of torch operators like torch.sort and
-            # torch.sigmoid mutate the tensors in the out field. Track such
-            # tensors and rewrite the symbolic locals.
+            # out variants of torch operators like torch.sort and torch.sigmoid
+            # mutate the tensors in the out field.
+            #
+            # However, it's non-trivial to update all references of the old
+            # `TensorVariable` to the new one returned (`result_var`), so we
+            # take the conservative approach of graph breaking on size changes,
+            # and assume the other cases can fall through soundly.
+            #
+            # Note that although these tensor variables would hold different
+            # proxies, the in-place mutation semantics are preserved in the FX
+            # graph, so we won't have correctness issues.
             if isinstance(tensor_variable, TupleVariable):
                 assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
-                output_tensor_names = [
-                    tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
-                ]
-                for idx, name in enumerate(output_tensor_names):
-                    if name in tx.symbolic_locals:
-                        tx.symbolic_locals[name] = tensor_variable.items[idx]
                 for out_tensor, result_tensor in zip(
                     kwargs["out"].items, tensor_variable.items
                 ):
                     if (
-                        out_tensor.source
-                        and out_tensor in tx.output.graphargs
-                        and isinstance(out_tensor, variables.TensorVariable)
+                        isinstance(out_tensor, variables.TensorVariable)
                         and isinstance(result_tensor, variables.TensorVariable)
                         and out_tensor._size != result_tensor._size
                         # we actually want to compare None values here
@@ -1003,11 +1003,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                     assert "example_value" in kwargs["out"].proxy.node.meta
                     fake_tensor = tensor_variable.proxy.node.meta["example_value"]
                     fake_out = kwargs["out"].proxy.node.meta["example_value"]
-                    if (
-                        kwargs["out"].source
-                        and kwargs["out"] in tx.output.graphargs
-                        and fake_out_shape != fake_tensor.shape
-                    ):
+                    if fake_out_shape != fake_tensor.shape:
                         # It's hard to get out variants with resizing on graph inputs work
                         # properly across dynamo/aot/inductor, just fall back.
                         unimplemented("out variants with resizing on graph inputs")
@@ -1017,9 +1013,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                 unimplemented(
                     "out= op was called where output tensor was non-contiguous"
                 )
-            name = tx.find_symbolic_locals_name(kwargs["out"])
-            if name in tx.symbolic_locals:
-                tx.symbolic_locals[name] = tensor_variable
         elif (
             isinstance(tensor_variable, ConstantVariable)
             and tensor_variable.value is None
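
Context (not part of the diff): the new tests rely on the eager-mode semantics of `out=` variants, where a zero-element output tensor is resized in place to the result's shape. A minimal eager-mode sketch, with illustrative variable names:

```python
import torch

x = torch.zeros(3, 5)

# Single-output case: eager mode resizes the 0-element `out` tensor in place.
out = torch.empty(0)
torch.sigmoid(x, out=out)
print(out.size())  # torch.Size([3, 5])

# Tuple-output case: torch.sort writes sorted values and indices into both tensors.
values = torch.empty(0)
indices = torch.empty(0, dtype=torch.long)
torch.sort(x, out=(values, indices))
print(values.size(), indices.size())  # torch.Size([3, 5]) torch.Size([3, 5])
```

Under torch.compile, the rewritten torch.py logic keeps this behavior by graph breaking when the output size would change on graph inputs and letting other cases fall through, with the in-place mutation still recorded in the FX graph.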