mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a followup PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph` so that `fullgraph` really returns 1 graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still private API (there are no internal callsites yet, and there are no significant OSS callsites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph` Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739 Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
177 lines
5.3 KiB
Diff
177 lines
5.3 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_bool.py b/test/dynamo/cpython/3_13/test_bool.py
|
|
index 34ecb45f161..12b719c432b 100644
|
|
--- a/test/dynamo/cpython/3_13/test_bool.py
|
|
+++ b/test/dynamo/cpython/3_13/test_bool.py
|
|
@@ -1,3 +1,23 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_bool.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
# Test properties of bool promised by PEP 285
|
|
|
|
import unittest
|
|
@@ -5,12 +25,13 @@ from test.support import os_helper
|
|
|
|
import os
|
|
|
|
-class BoolTest(unittest.TestCase):
|
|
+class BoolTest(__TestCase):
|
|
|
|
def test_subclass(self):
|
|
try:
|
|
- class C(bool):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class C(bool):
|
|
+ pass
|
|
except TypeError:
|
|
pass
|
|
else:
|
|
@@ -307,40 +328,46 @@ class BoolTest(unittest.TestCase):
|
|
# from __bool__(). This isn't really a bool test, but
|
|
# it's related.
|
|
check = lambda o: self.assertRaises(TypeError, bool, o)
|
|
- class Foo(object):
|
|
- def __bool__(self):
|
|
- return self
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Foo(object):
|
|
+ def __bool__(self):
|
|
+ return self
|
|
check(Foo())
|
|
|
|
- class Bar(object):
|
|
- def __bool__(self):
|
|
- return "Yes"
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Bar(object):
|
|
+ def __bool__(self):
|
|
+ return "Yes"
|
|
check(Bar())
|
|
|
|
- class Baz(int):
|
|
- def __bool__(self):
|
|
- return self
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Baz(int):
|
|
+ def __bool__(self):
|
|
+ return self
|
|
check(Baz())
|
|
|
|
# __bool__() must return a bool not an int
|
|
- class Spam(int):
|
|
- def __bool__(self):
|
|
- return 1
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Spam(int):
|
|
+ def __bool__(self):
|
|
+ return 1
|
|
check(Spam())
|
|
|
|
- class Eggs:
|
|
- def __len__(self):
|
|
- return -1
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Eggs:
|
|
+ def __len__(self):
|
|
+ return -1
|
|
self.assertRaises(ValueError, bool, Eggs())
|
|
|
|
def test_interpreter_convert_to_bool_raises(self):
|
|
- class SymbolicBool:
|
|
- def __bool__(self):
|
|
- raise TypeError
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class SymbolicBool:
|
|
+ def __bool__(self):
|
|
+ raise TypeError
|
|
|
|
- class Symbol:
|
|
- def __gt__(self, other):
|
|
- return SymbolicBool()
|
|
+ class Symbol:
|
|
+ def __gt__(self, other):
|
|
+ return SymbolicBool()
|
|
|
|
x = Symbol()
|
|
|
|
@@ -361,9 +388,10 @@ class BoolTest(unittest.TestCase):
|
|
# this test just tests our assumptions about __len__
|
|
# this will start failing if __len__ changes assertions
|
|
for badval in ['illegal', -1, 1 << 32]:
|
|
- class A:
|
|
- def __len__(self):
|
|
- return badval
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class A:
|
|
+ def __len__(self):
|
|
+ return badval
|
|
try:
|
|
bool(A())
|
|
except (Exception) as e_bool:
|
|
@@ -373,14 +401,16 @@ class BoolTest(unittest.TestCase):
|
|
self.assertEqual(str(e_bool), str(e_len))
|
|
|
|
def test_blocked(self):
|
|
- class A:
|
|
- __bool__ = None
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class A:
|
|
+ __bool__ = None
|
|
self.assertRaises(TypeError, bool, A())
|
|
|
|
- class B:
|
|
- def __len__(self):
|
|
- return 10
|
|
- __bool__ = None
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class B:
|
|
+ def __len__(self):
|
|
+ return 10
|
|
+ __bool__ = None
|
|
self.assertRaises(TypeError, bool, B())
|
|
|
|
def test_real_and_imag(self):
|
|
@@ -394,12 +424,13 @@ class BoolTest(unittest.TestCase):
|
|
self.assertIs(type(False.imag), int)
|
|
|
|
def test_bool_called_at_least_once(self):
|
|
- class X:
|
|
- def __init__(self):
|
|
- self.count = 0
|
|
- def __bool__(self):
|
|
- self.count += 1
|
|
- return True
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class X:
|
|
+ def __init__(self):
|
|
+ self.count = 0
|
|
+ def __bool__(self):
|
|
+ self.count += 1
|
|
+ return True
|
|
|
|
def f(x):
|
|
if x or True:
|
|
@@ -418,4 +449,4 @@ class BoolTest(unittest.TestCase):
|
|
|
|
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|