graph break on tolist if capture_scalar_outputs is false (#163807)

Addresses https://github.com/pytorch/pytorch/issues/163798

It is problematic not to graph break here because:
1. It breaks the current contract for tolist() when capture_scalar_outputs is false.
2. If Dynamo traces through and emits the underlying .item() calls, then any later re-trace (for example in autograd) hits a failure, because at that point we no longer know where to graph break. See the added unit test, and the sketch below.
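
A minimal sketch of the new contract, modeled on the added test (CompileCounter and the exact frame counts here are illustrative, not part of this change):

    import torch
    from torch._dynamo.testing import CompileCounter

    def f(a):
        a = a * 100
        # Without capture_scalar_outputs, .tolist() should now trigger a graph
        # break here instead of being traced via data-dependent .item() calls.
        u0, u1, u2, u3, u4 = a.tolist()
        return a * u0 * u1

    cnt = CompileCounter()
    torch.compile(f, backend=cnt, fullgraph=False)(torch.tensor([1, 2, 3, 4, 5]))
    assert cnt.frame_count == 2  # two frames => a graph break at tolist()

    # Opting back in restores the previous single-graph capture.
    torch._dynamo.reset()
    with torch._dynamo.config.patch(capture_scalar_outputs=True):
        cnt2 = CompileCounter()
        torch.compile(f, backend=cnt2, fullgraph=False)(torch.tensor([1, 2, 3, 4, 5]))
        assert cnt2.frame_count == 1  # captured as a single graph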

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163807
Approved by: https://github.com/bobrenjc93
Laith Sakka 2025-09-26 16:27:28 -07:00 committed by PyTorch MergeBot
parent 3059b08012
commit b377c9e365
30 changed files with 55 additions and 39 deletions

View File

@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -182,7 +182,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1

View File

@@ -182,7 +182,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1

View File

@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -166,7 +166,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1

View File

@@ -166,7 +166,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1

View File

@@ -182,7 +182,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -221,7 +221,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -154,7 +154,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -221,7 +221,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -154,7 +154,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -142,7 +142,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -221,7 +221,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2

View File

@@ -154,7 +154,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -142,7 +142,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9

View File

@@ -815,6 +815,7 @@ class GraphModule(torch.nn.Module):
     @torch._dynamo.config.patch(
         capture_dynamic_output_shape_ops=True,
+        capture_scalar_outputs=True,
     )
     def test_tensor_to_list_closure(self):
         def f(x):

View File

@@ -2591,6 +2591,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         y = fn(x)
         self.assertTrue(y.flags.writeable) # XXX: differs from numpy
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_numpy_tolist(self):
         def fn(x):
             return x.tolist()
@@ -7967,6 +7968,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertEqual(fn(torch.tensor([4])).size(0), 1)
         self.assertEqual(fn(torch.tensor([1])).size(0), 0)
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_sym_and_terms(self):
         from torch.fx.experimental.symbolic_shapes import sym_and
@@ -8137,6 +8139,19 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         torch._dynamo.reset()
         torch.compile(my_dyn_fn, backend="eager")(y, y)
+    def test_tolist(self):
+        # This should compile with no failure.
+        cnt = CompileCounterWithBackend("inductor")
+
+        @torch.compile(fullgraph=False, backend=cnt)
+        def func(a):
+            a = a * 100
+            u0, u1, u2, u3, u4 = a.tolist()
+            return a * u0 * u1
+
+        func(torch.tensor([1, 2, 3, 4, 5]))
+        self.assertEqual(cnt.frame_count, 2)
     # Sadly, this does not throw - we do not prop correctly across the graph break
     @unittest.expectedFailure
     def test_raise_guard_partial_constraint_across_break(self):
@@ -9688,6 +9703,7 @@ def ___make_guard_fn():
         img2 = torch.randn(1, 3, 8, 15)
         self.assertRaises(AssertionError, lambda: fn(img2))
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_tolist_scalar(self):
         def fn(x):
             new_list = []
@@ -9702,6 +9718,7 @@ def ___make_guard_fn():
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_tolist_1d(self):
         def fn(x):
             new_list = []
@@ -9716,6 +9733,7 @@ def ___make_guard_fn():
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_tolist_kd(self):
         def fn(x):
             new_list = []
@@ -9730,6 +9748,7 @@ def ___make_guard_fn():
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     @patch.object(torch._dynamo.config, "specialize_int", True)
     def test_tolist_0d(self):
         def fn(x):
@@ -9752,12 +9771,12 @@ def ___make_guard_fn():
             new_list = []
             i = x.tolist()
             new_list.append(i * 4)
-            return new_list
+            return new_list, x * 10
         x = torch.randint(3, 5, [5, 5])
         eager = fn(x)
         counter = CompileCounter()
-        compiled_fn = torch.compile(fn, backend=counter, fullgraph=True)
+        compiled_fn = torch.compile(fn, backend=counter, fullgraph=False)
         compiled = compiled_fn(x)
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)
@@ -11299,6 +11318,7 @@ fn
         c2 = _debug_get_cache_entry_list(fn.__code__)
         self.assertEqual(len(c2), 0)
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_guard_size_oblivious_simplification(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
@@ -11328,6 +11348,7 @@ fn
         with self.assertRaisesRegex(RuntimeError, "specialized"):
             fn(x, y)
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_sym_max_unbacked_sizelike_simplification(self):
         @torch.compile(fullgraph=True, backend="eager")
         def cf(x):

View File

@@ -23,7 +23,6 @@ import operator
 import textwrap
 import traceback
 import types
-import unittest
 from typing import TYPE_CHECKING

 import sympy
@@ -945,15 +944,10 @@ class TensorVariable(VariableTracker):
         def tolist(tensor, sub_proxy):
             def wrap(i, sub_proxy):
-                # Sigh, we forgot to gate this, so this data dependent is on
-                # by default and is load bearing in CI
-                with unittest.mock.patch.object(
-                    tx.fake_mode, "allow_scalar_outputs", True
-                ):
-                    return wrap_fx_proxy(
-                        tx,
-                        sub_proxy.item(),
-                    )
+                return wrap_fx_proxy(
+                    tx,
+                    sub_proxy.item(),
+                )

         if tensor.dtype not in [
             torch.int8,