Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
graph break on tolist if capture_scalar_outputs is false (#163807)
Addresses https://github.com/pytorch/pytorch/issues/163798. It is problematic not to graph break here because: 1. it breaks the current contract; 2. if Dynamo traces through, we end up with an .item() call in the graph, and if we ever re-trace later (for example in autograd) we hit a failure, since at that point we no longer know where to graph break. See the added unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163807
Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
3059b08012
commit
b377c9e365
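For readers skimming the diff, here is a minimal sketch of the behavior this change enforces. It is not code from the PR; it assumes the public torch.compile API and the CompileCounter test backend from torch._dynamo.testing, and it mirrors the added test_tolist unit test shown further down.

# Sketch only: with capture_scalar_outputs=False (the default), Tensor.tolist()
# now triggers a graph break instead of tracing hidden .item() calls.
import torch
from torch._dynamo.testing import CompileCounter

cnt = CompileCounter()

@torch.compile(backend=cnt, fullgraph=False)
def func(a):
    a = a * 100
    u0, u1, u2, u3, u4 = a.tolist()  # graph break happens here by default
    return a * u0 * u1

func(torch.tensor([1, 2, 3, 4, 5]))
print(cnt.frame_count)  # expected 2: one frame before the break and one after

# With capture_scalar_outputs=True, the same function traces the .item() calls
# as unbacked symbols and compiles without the graph break, e.g. under
# @torch._dynamo.config.patch(capture_scalar_outputs=True).

The added test in the diff below asserts exactly this frame_count == 2 split, using an inductor-backed compile counter.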
@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -182,7 +182,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1
@@ -182,7 +182,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1
@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -166,7 +166,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1
@@ -166,7 +166,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,0
+microbench_unbacked_tolist_sum,pass,1
@@ -182,7 +182,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -198,7 +198,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -218,7 +218,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -146,7 +146,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -221,7 +221,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -154,7 +154,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -221,7 +221,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -154,7 +154,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -142,7 +142,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -221,7 +221,7 @@ maml_omniglot,pass,0
-microbench_unbacked_tolist_sum,pass,1
+microbench_unbacked_tolist_sum,pass,2
@@ -154,7 +154,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -142,7 +142,7 @@ maml_omniglot,pass,7
-microbench_unbacked_tolist_sum,pass,8
+microbench_unbacked_tolist_sum,pass,9
@@ -815,6 +815,7 @@ class GraphModule(torch.nn.Module):
     @torch._dynamo.config.patch(
         capture_dynamic_output_shape_ops=True,
+        capture_scalar_outputs=True,
     )
     def test_tensor_to_list_closure(self):
         def f(x):
@@ -2591,6 +2591,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         y = fn(x)
         self.assertTrue(y.flags.writeable)  # XXX: differs from numpy

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_numpy_tolist(self):
         def fn(x):
             return x.tolist()
@@ -7967,6 +7968,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertEqual(fn(torch.tensor([4])).size(0), 1)
         self.assertEqual(fn(torch.tensor([1])).size(0), 0)

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_sym_and_terms(self):
         from torch.fx.experimental.symbolic_shapes import sym_and
@@ -8137,6 +8139,19 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         torch._dynamo.reset()
         torch.compile(my_dyn_fn, backend="eager")(y, y)

+    def test_tolist(self):
+        # This should compile with no failure.
+        cnt = CompileCounterWithBackend("inductor")
+
+        @torch.compile(fullgraph=False, backend=cnt)
+        def func(a):
+            a = a * 100
+            u0, u1, u2, u3, u4 = a.tolist()
+            return a * u0 * u1
+
+        func(torch.tensor([1, 2, 3, 4, 5]))
+        self.assertEqual(cnt.frame_count, 2)
+
     # Sadly, this does not throw - we do not prop correctly across the graph break
     @unittest.expectedFailure
     def test_raise_guard_partial_constraint_across_break(self):
@@ -9688,6 +9703,7 @@ def ___make_guard_fn():
         img2 = torch.randn(1, 3, 8, 15)
         self.assertRaises(AssertionError, lambda: fn(img2))

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_tolist_scalar(self):
         def fn(x):
             new_list = []
@@ -9702,6 +9718,7 @@ def ___make_guard_fn():
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_tolist_1d(self):
         def fn(x):
             new_list = []
@@ -9716,6 +9733,7 @@ def ___make_guard_fn():
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_tolist_kd(self):
         def fn(x):
             new_list = []
@@ -9730,6 +9748,7 @@ def ___make_guard_fn():
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     @patch.object(torch._dynamo.config, "specialize_int", True)
     def test_tolist_0d(self):
         def fn(x):
@@ -9752,12 +9771,12 @@ def ___make_guard_fn():
             new_list = []
             i = x.tolist()
             new_list.append(i * 4)
-            return new_list
+            return new_list, x * 10

         x = torch.randint(3, 5, [5, 5])
         eager = fn(x)
         counter = CompileCounter()
-        compiled_fn = torch.compile(fn, backend=counter, fullgraph=True)
+        compiled_fn = torch.compile(fn, backend=counter, fullgraph=False)
         compiled = compiled_fn(x)
         self.assertEqual(eager, compiled)
         self.assertEqual(counter.frame_count, 1)
@@ -11299,6 +11318,7 @@ fn
         c2 = _debug_get_cache_entry_list(fn.__code__)
         self.assertEqual(len(c2), 0)

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_guard_size_oblivious_simplification(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
@@ -11328,6 +11348,7 @@ fn
         with self.assertRaisesRegex(RuntimeError, "specialized"):
             fn(x, y)

+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_sym_max_unbacked_sizelike_simplification(self):
         @torch.compile(fullgraph=True, backend="eager")
         def cf(x):
@@ -23,7 +23,6 @@ import operator
 import textwrap
 import traceback
 import types
-import unittest
 from typing import TYPE_CHECKING

 import sympy
@@ -945,15 +944,10 @@ class TensorVariable(VariableTracker):

         def tolist(tensor, sub_proxy):
             def wrap(i, sub_proxy):
-                # Sigh, we forgot to gate this, so this data dependent is on
-                # by default and is load bearing in CI
-                with unittest.mock.patch.object(
-                    tx.fake_mode, "allow_scalar_outputs", True
-                ):
-                    return wrap_fx_proxy(
-                        tx,
-                        sub_proxy.item(),
-                    )
+                return wrap_fx_proxy(
+                    tx,
+                    sub_proxy.item(),
+                )

             if tensor.dtype not in [
                 torch.int8,