pytorch/test/dynamo/cpython/3_13/test_bool.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Rename `set_fullgraph` to `error_on_graph_break`. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, with lower priority than `fullgraph`, so that `fullgraph` truly results in exactly one graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I hope that won't be necessary, since it is still a private API (there are no internal call sites yet, and no significant OSS call sites yet).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as the primary users of `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

177 lines
5.3 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_bool.py b/test/dynamo/cpython/3_13/test_bool.py
index 34ecb45f161..12b719c432b 100644
--- a/test/dynamo/cpython/3_13/test_bool.py
+++ b/test/dynamo/cpython/3_13/test_bool.py
@@ -1,3 +1,23 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_bool.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+# ======= END DYNAMO PATCH =======
+
# Test properties of bool promised by PEP 285
import unittest
@@ -5,12 +25,13 @@ from test.support import os_helper
import os
-class BoolTest(unittest.TestCase):
+class BoolTest(__TestCase):
def test_subclass(self):
try:
- class C(bool):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C(bool):
+ pass
except TypeError:
pass
else:
@@ -307,40 +328,46 @@ class BoolTest(unittest.TestCase):
# from __bool__(). This isn't really a bool test, but
# it's related.
check = lambda o: self.assertRaises(TypeError, bool, o)
- class Foo(object):
- def __bool__(self):
- return self
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo(object):
+ def __bool__(self):
+ return self
check(Foo())
- class Bar(object):
- def __bool__(self):
- return "Yes"
+ with torch._dynamo.error_on_graph_break(False):
+ class Bar(object):
+ def __bool__(self):
+ return "Yes"
check(Bar())
- class Baz(int):
- def __bool__(self):
- return self
+ with torch._dynamo.error_on_graph_break(False):
+ class Baz(int):
+ def __bool__(self):
+ return self
check(Baz())
# __bool__() must return a bool not an int
- class Spam(int):
- def __bool__(self):
- return 1
+ with torch._dynamo.error_on_graph_break(False):
+ class Spam(int):
+ def __bool__(self):
+ return 1
check(Spam())
- class Eggs:
- def __len__(self):
- return -1
+ with torch._dynamo.error_on_graph_break(False):
+ class Eggs:
+ def __len__(self):
+ return -1
self.assertRaises(ValueError, bool, Eggs())
def test_interpreter_convert_to_bool_raises(self):
- class SymbolicBool:
- def __bool__(self):
- raise TypeError
+ with torch._dynamo.error_on_graph_break(False):
+ class SymbolicBool:
+ def __bool__(self):
+ raise TypeError
- class Symbol:
- def __gt__(self, other):
- return SymbolicBool()
+ class Symbol:
+ def __gt__(self, other):
+ return SymbolicBool()
x = Symbol()
@@ -361,9 +388,10 @@ class BoolTest(unittest.TestCase):
# this test just tests our assumptions about __len__
# this will start failing if __len__ changes assertions
for badval in ['illegal', -1, 1 << 32]:
- class A:
- def __len__(self):
- return badval
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def __len__(self):
+ return badval
try:
bool(A())
except (Exception) as e_bool:
@@ -373,14 +401,16 @@ class BoolTest(unittest.TestCase):
self.assertEqual(str(e_bool), str(e_len))
def test_blocked(self):
- class A:
- __bool__ = None
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ __bool__ = None
self.assertRaises(TypeError, bool, A())
- class B:
- def __len__(self):
- return 10
- __bool__ = None
+ with torch._dynamo.error_on_graph_break(False):
+ class B:
+ def __len__(self):
+ return 10
+ __bool__ = None
self.assertRaises(TypeError, bool, B())
def test_real_and_imag(self):
@@ -394,12 +424,13 @@ class BoolTest(unittest.TestCase):
self.assertIs(type(False.imag), int)
def test_bool_called_at_least_once(self):
- class X:
- def __init__(self):
- self.count = 0
- def __bool__(self):
- self.count += 1
- return True
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __init__(self):
+ self.count = 0
+ def __bool__(self):
+ self.count += 1
+ return True
def f(x):
if x or True:
@@ -418,4 +449,4 @@ class BoolTest(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()