[dynamo] rename set_fullgraph to error_on_graph_break (#161739)

Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a followup PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph` so that `fullgraph` really returns 1 graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still private API (there are no internal callsites yet, and there are no significant OSS callsites yet).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
William Wen 2025-09-03 14:25:40 -07:00 committed by PyTorch MergeBot
parent 1281470155
commit 8678d831c4
57 changed files with 1577 additions and 1571 deletions

View File

@ -62,16 +62,16 @@ index dbc5ef4f9f2..af717703053 100644
@@ -5,7 +58,7 @@ Tests common to list and UserList.UserList
import sys
from functools import cmp_to_key
-from test import seq_tests
+import seq_tests
from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit
@@ -119,10 +172,6 @@ class CommonTest(seq_tests.CommonTest):
a[-1] = 9
self.assertEqual(a, self.type2test([5,6,7,8,9]))
- msg = "list indices must be integers or slices"
- with self.assertRaisesRegex(TypeError, msg):
- a['a'] = "python"
@ -81,7 +81,7 @@ index dbc5ef4f9f2..af717703053 100644
del a[1]
@@ -270,13 +319,14 @@ class CommonTest(seq_tests.CommonTest):
self.assertRaises(TypeError, a.extend)
# overflow test. issue1621
- class CustomIter:
- def __iter__(self):
@ -90,7 +90,7 @@ index dbc5ef4f9f2..af717703053 100644
- raise StopIteration
- def __length_hint__(self):
- return sys.maxsize
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomIter:
+ def __iter__(self):
+ return self
@ -104,13 +104,13 @@ index dbc5ef4f9f2..af717703053 100644
@@ -337,21 +387,23 @@ class CommonTest(seq_tests.CommonTest):
a = self.type2test([NEVER_EQ])
self.assertRaises(ValueError, a.remove, ALWAYS_EQ)
- class BadExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass
- class BadCmp:
- def __eq__(self, other):
- if other == 2:
@ -121,24 +121,24 @@ index dbc5ef4f9f2..af717703053 100644
+ if other == 2:
+ raise BadExc()
+ return False
a = self.type2test([0, 1, 2, 3])
self.assertRaises(BadExc, a.remove, BadCmp())
- class BadCmp2:
- def __eq__(self, other):
- raise BadExc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadCmp2:
+ def __eq__(self, other):
+ raise BadExc()
d = self.type2test('abcdefghcij')
d.remove('c')
@@ -376,13 +428,14 @@ class CommonTest(seq_tests.CommonTest):
self.assertRaises(ValueError, a.index, 2, 0, 4)
self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2]))
- # Test modifying the list during index's iteration
- class EvilCmp:
- def __init__(self, victim):
@ -146,7 +146,7 @@ index dbc5ef4f9f2..af717703053 100644
- def __eq__(self, other):
- del self.victim[:]
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Test modifying the list during index's iteration
+ class EvilCmp:
+ def __init__(self, victim):

View File

@ -319,7 +319,7 @@ class CommonTest(seq_tests.CommonTest):
self.assertRaises(TypeError, a.extend)
# overflow test. issue1621
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomIter:
def __iter__(self):
return self
@ -387,7 +387,7 @@ class CommonTest(seq_tests.CommonTest):
a = self.type2test([NEVER_EQ])
self.assertRaises(ValueError, a.remove, ALWAYS_EQ)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadExc(Exception):
pass
@ -400,7 +400,7 @@ class CommonTest(seq_tests.CommonTest):
a = self.type2test([0, 1, 2, 3])
self.assertRaises(BadExc, a.remove, BadCmp())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadCmp2:
def __eq__(self, other):
raise BadExc()
@ -428,7 +428,7 @@ class CommonTest(seq_tests.CommonTest):
self.assertRaises(ValueError, a.index, 2, 0, 4)
self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2]))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Test modifying the list during index's iteration
class EvilCmp:
def __init__(self, victim):

View File

@ -61,16 +61,16 @@ index ed89a81a6ea..b19cec7cb23 100644
import unittest
import collections
from test.support import get_c_recursion_limit
-class BasicTestMappingProtocol(unittest.TestCase):
+class BasicTestMappingProtocol(__TestCase):
# This base class can be used to check that an object conforms to the
# mapping protocol
@@ -196,70 +250,76 @@ class BasicTestMappingProtocol(unittest.TestCase):
self.assertRaises((TypeError, AttributeError), d.update, 42)
outerself = self
- class SimpleUserDict:
- def __init__(self):
@ -79,7 +79,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return self.d.keys()
- def __getitem__(self, i):
- return self.d[i]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SimpleUserDict:
+ def __init__(self):
+ self.d = outerself.reference
@ -92,23 +92,23 @@ index ed89a81a6ea..b19cec7cb23 100644
i1 = sorted(d.items())
i2 = sorted(self.reference.items())
self.assertEqual(i1, i2)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
d = self._empty_mapping()
- class FailingUserDict:
- def keys(self):
- raise Exc
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ raise Exc
self.assertRaises(Exc, d.update, FailingUserDict())
d.clear()
- class FailingUserDict:
- def keys(self):
- class BogonIter:
@ -124,7 +124,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return BogonIter()
- def __getitem__(self, key):
- return key
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ class BogonIter:
@ -141,7 +141,7 @@ index ed89a81a6ea..b19cec7cb23 100644
+ def __getitem__(self, key):
+ return key
self.assertRaises(Exc, d.update, FailingUserDict())
- class FailingUserDict:
- def keys(self):
- class BogonIter:
@ -158,7 +158,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return BogonIter()
- def __getitem__(self, key):
- raise Exc
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ class BogonIter:
@ -176,26 +176,26 @@ index ed89a81a6ea..b19cec7cb23 100644
+ def __getitem__(self, key):
+ raise Exc
self.assertRaises(Exc, d.update, FailingUserDict())
d = self._empty_mapping()
- class badseq(object):
- def __iter__(self):
- return self
- def __next__(self):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class badseq(object):
+ def __iter__(self):
+ return self
+ def __next__(self):
+ raise Exc()
self.assertRaises(Exc, d.update, badseq())
@@ -409,13 +469,14 @@ class TestMappingProtocol(BasicTestMappingProtocol):
d.update(self._full_mapping({1:2, 3:4, 5:6}).items())
self.assertEqual(d, {1:2, 2:4, 3:4, 5:6})
- class SimpleUserDict:
- def __init__(self):
- self.d = {1:1, 2:2, 3:3}
@ -203,7 +203,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return self.d.keys()
- def __getitem__(self, i):
- return self.d[i]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SimpleUserDict:
+ def __init__(self):
+ self.d = {1:1, 2:2, 3:3}
@ -219,7 +219,7 @@ index ed89a81a6ea..b19cec7cb23 100644
self.assertEqual(d.fromkeys(g()), {1:None})
self.assertRaises(TypeError, {}.fromkeys, 3)
- class dictlike(self.type2test): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class dictlike(self.type2test): pass
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
@ -229,7 +229,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- class mydict(self.type2test):
- def __new__(cls):
- return collections.UserDict()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mydict(self.type2test):
+ def __new__(cls):
+ return collections.UserDict()
@ -237,52 +237,52 @@ index ed89a81a6ea..b19cec7cb23 100644
self.assertEqual(ud, {'a':None, 'b':None})
self.assertIsInstance(ud, collections.UserDict)
self.assertRaises(TypeError, dict.fromkeys)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class baddict1(self.type2test):
- def __init__(self, *args, **kwargs):
- raise Exc()
+ class baddict1(self.type2test):
+ def __init__(self, *args, **kwargs):
+ raise Exc()
self.assertRaises(Exc, baddict1.fromkeys, [1])
- class BadSeq(object):
- def __iter__(self):
- return self
- def __next__(self):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadSeq(object):
+ def __iter__(self):
+ return self
+ def __next__(self):
+ raise Exc()
self.assertRaises(Exc, self.type2test.fromkeys, BadSeq())
- class baddict2(self.type2test):
- def __setitem__(self, key, value):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class baddict2(self.type2test):
+ def __setitem__(self, key, value):
+ raise Exc()
self.assertRaises(Exc, baddict2.fromkeys, [1])
@@ -537,25 +603,27 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_getitem(self):
TestMappingProtocol.test_getitem(self)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadEq(object):
- def __eq__(self, other):
- raise Exc()
@ -293,11 +293,11 @@ index ed89a81a6ea..b19cec7cb23 100644
+ raise Exc()
+ def __hash__(self):
+ return 24
d = self._empty_mapping()
d[BadEq()] = 42
self.assertRaises(KeyError, d.__getitem__, 23)
- class BadHash(object):
- fail = False
- def __hash__(self):
@ -305,7 +305,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- raise Exc()
- else:
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadHash(object):
+ fail = False
+ def __hash__(self):
@ -313,17 +313,17 @@ index ed89a81a6ea..b19cec7cb23 100644
+ raise Exc()
+ else:
+ return 42
d = self._empty_mapping()
x = BadHash()
@@ -565,9 +633,10 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_fromkeys(self):
TestMappingProtocol.test_fromkeys(self)
- class mydict(self.type2test):
- def __new__(cls):
- return collections.UserDict()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mydict(self.type2test):
+ def __new__(cls):
+ return collections.UserDict()
@ -333,11 +333,11 @@ index ed89a81a6ea..b19cec7cb23 100644
@@ -575,15 +644,16 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_pop(self):
TestMappingProtocol.test_pop(self)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadHash(object):
- fail = False
- def __hash__(self):
@ -352,34 +352,34 @@ index ed89a81a6ea..b19cec7cb23 100644
+ raise Exc()
+ else:
+ return 42
d = self._empty_mapping()
x = BadHash()
@@ -613,11 +683,12 @@ class TestHashMappingProtocol(TestMappingProtocol):
d[1] = d
self.assertEqual(repr(d), '{1: {...}}')
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadRepr(object):
- def __repr__(self):
- raise Exc()
+ class BadRepr(object):
+ def __repr__(self):
+ raise Exc()
d = self._full_mapping({1: BadRepr()})
self.assertRaises(Exc, repr, d)
@@ -635,13 +706,14 @@ class TestHashMappingProtocol(TestMappingProtocol):
self.assertEqual(self._full_mapping({1: 2}),
self._full_mapping({1: 2}))
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadCmp(object):
- def __eq__(self, other):
- raise Exc()
@ -390,17 +390,17 @@ index ed89a81a6ea..b19cec7cb23 100644
+ raise Exc()
+ def __hash__(self):
+ return 1
d1 = self._full_mapping({BadCmp(): 1})
d2 = self._full_mapping({1: 1})
@@ -651,15 +723,16 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_setdefault(self):
TestMappingProtocol.test_setdefault(self)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadHash(object):
- fail = False
- def __hash__(self):
@ -415,6 +415,6 @@ index ed89a81a6ea..b19cec7cb23 100644
+ raise Exc()
+ else:
+ return 42
d = self._empty_mapping()
x = BadHash()

View File

@ -250,7 +250,7 @@ class BasicTestMappingProtocol(__TestCase):
self.assertRaises((TypeError, AttributeError), d.update, 42)
outerself = self
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SimpleUserDict:
def __init__(self):
self.d = outerself.reference
@ -264,11 +264,11 @@ class BasicTestMappingProtocol(__TestCase):
i2 = sorted(self.reference.items())
self.assertEqual(i1, i2)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
d = self._empty_mapping()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
raise Exc
@ -276,7 +276,7 @@ class BasicTestMappingProtocol(__TestCase):
d.clear()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -294,7 +294,7 @@ class BasicTestMappingProtocol(__TestCase):
return key
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -314,7 +314,7 @@ class BasicTestMappingProtocol(__TestCase):
self.assertRaises(Exc, d.update, FailingUserDict())
d = self._empty_mapping()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class badseq(object):
def __iter__(self):
return self
@ -469,7 +469,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
d.update(self._full_mapping({1:2, 3:4, 5:6}).items())
self.assertEqual(d, {1:2, 2:4, 3:4, 5:6})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SimpleUserDict:
def __init__(self):
self.d = {1:1, 2:2, 3:3}
@ -492,14 +492,14 @@ class TestMappingProtocol(BasicTestMappingProtocol):
yield 1
self.assertEqual(d.fromkeys(g()), {1:None})
self.assertRaises(TypeError, {}.fromkeys, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class dictlike(self.type2test): pass
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
self.assertTrue(dictlike.fromkeys('a').__class__ is dictlike)
self.assertTrue(dictlike().fromkeys('a').__class__ is dictlike)
self.assertTrue(type(dictlike.fromkeys('a')) is dictlike)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mydict(self.type2test):
def __new__(cls):
return collections.UserDict()
@ -508,7 +508,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
self.assertIsInstance(ud, collections.UserDict)
self.assertRaises(TypeError, dict.fromkeys)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class baddict1(self.type2test):
@ -517,7 +517,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
self.assertRaises(Exc, baddict1.fromkeys, [1])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadSeq(object):
def __iter__(self):
return self
@ -526,7 +526,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
self.assertRaises(Exc, self.type2test.fromkeys, BadSeq())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict2(self.type2test):
def __setitem__(self, key, value):
raise Exc()
@ -603,7 +603,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_getitem(self):
TestMappingProtocol.test_getitem(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadEq(object):
@ -616,7 +616,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
d[BadEq()] = 42
self.assertRaises(KeyError, d.__getitem__, 23)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadHash(object):
fail = False
def __hash__(self):
@ -633,7 +633,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_fromkeys(self):
TestMappingProtocol.test_fromkeys(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mydict(self.type2test):
def __new__(cls):
return collections.UserDict()
@ -644,7 +644,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_pop(self):
TestMappingProtocol.test_pop(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -683,7 +683,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
d[1] = d
self.assertEqual(repr(d), '{1: {...}}')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadRepr(object):
@ -706,7 +706,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
self.assertEqual(self._full_mapping({1: 2}),
self._full_mapping({1: 2}))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadCmp(object):
@ -723,7 +723,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_setdefault(self):
TestMappingProtocol.test_setdefault(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):

View File

@ -63,15 +63,15 @@ index 719c9434a16..290e57c04a0 100644
@@ -95,7 +149,7 @@ class LyingList(list):
def __iter__(self):
yield 1
-class CommonTest(unittest.TestCase):
+class CommonTest(__TestCase):
# The type to be tested
type2test = None
@@ -115,13 +169,14 @@ class CommonTest(unittest.TestCase):
uu2 = self.type2test(u2)
v = self.type2test(tuple(u))
- class OtherSeq:
- def __init__(self, initseq):
@ -80,7 +80,7 @@ index 719c9434a16..290e57c04a0 100644
- return len(self.__data)
- def __getitem__(self, i):
- return self.__data[i]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class OtherSeq:
+ def __init__(self, initseq):
+ self.__data = initseq
@ -100,51 +100,51 @@ index 719c9434a16..290e57c04a0 100644
- class StopCompares:
- def __eq__(self, other):
- raise DoNotTestEq
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DoNotTestEq(Exception):
+ pass
+ class StopCompares:
+ def __eq__(self, other):
+ raise DoNotTestEq
checkfirst = self.type2test([1, StopCompares()])
self.assertIn(1, checkfirst)
@@ -283,8 +339,9 @@ class CommonTest(unittest.TestCase):
self.assertEqual(u2+u2+u2, u2*3)
self.assertEqual(u2+u2+u2, 3*u2)
- class subclass(self.type2test):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(self.type2test):
+ pass
u3 = subclass([0, 1])
self.assertEqual(u3, u3*1)
self.assertIsNot(u3, u3*1)
@@ -311,9 +368,10 @@ class CommonTest(unittest.TestCase):
def test_getitemoverwriteiter(self):
# Verify that __getitem__ overrides are not recognized by __iter__
- class T(self.type2test):
- def __getitem__(self, key):
- return str(key) + '!!!'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(self.type2test):
+ def __getitem__(self, key):
+ return str(key) + '!!!'
self.assertEqual(next(iter(T((1,2)))), 1)
def test_repeat(self):
@@ -361,14 +419,15 @@ class CommonTest(unittest.TestCase):
self.assertRaises(TypeError, a.count)
- class BadExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass
- class BadCmp:
- def __eq__(self, other):
- if other == 2:
@ -155,19 +155,19 @@ index 719c9434a16..290e57c04a0 100644
+ if other == 2:
+ raise BadExc()
+ return False
self.assertRaises(BadExc, a.count, BadCmp())
@@ -394,14 +453,15 @@ class CommonTest(unittest.TestCase):
self.assertRaises(TypeError, u.index)
- class BadExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass
- class BadCmp:
- def __eq__(self, other):
- if other == 2:
@ -178,6 +178,6 @@ index 719c9434a16..290e57c04a0 100644
+ if other == 2:
+ raise BadExc()
+ return False
a = self.type2test([0, 1, 2, 3])
self.assertRaises(BadExc, a.index, BadCmp())

View File

@ -169,7 +169,7 @@ class CommonTest(__TestCase):
uu2 = self.type2test(u2)
v = self.type2test(tuple(u))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class OtherSeq:
def __init__(self, initseq):
self.__data = initseq
@ -294,7 +294,7 @@ class CommonTest(__TestCase):
# Sequences must test in-order. If a rich comparison has side
# effects, these will be visible to tests against later members.
# In this test, the "side effect" is a short-circuiting raise.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DoNotTestEq(Exception):
pass
class StopCompares:
@ -339,7 +339,7 @@ class CommonTest(__TestCase):
self.assertEqual(u2+u2+u2, u2*3)
self.assertEqual(u2+u2+u2, 3*u2)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(self.type2test):
pass
u3 = subclass([0, 1])
@ -368,7 +368,7 @@ class CommonTest(__TestCase):
def test_getitemoverwriteiter(self):
# Verify that __getitem__ overrides are not recognized by __iter__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(self.type2test):
def __getitem__(self, key):
return str(key) + '!!!'
@ -419,7 +419,7 @@ class CommonTest(__TestCase):
self.assertRaises(TypeError, a.count)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadExc(Exception):
pass
@ -453,7 +453,7 @@ class CommonTest(__TestCase):
self.assertRaises(TypeError, u.index)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadExc(Exception):
pass

View File

@ -24,20 +24,20 @@ index 34ecb45f161..12b719c432b 100644
+# ======= END DYNAMO PATCH =======
+
# Test properties of bool promised by PEP 285
import unittest
@@ -5,12 +25,13 @@ from test.support import os_helper
import os
-class BoolTest(unittest.TestCase):
+class BoolTest(__TestCase):
def test_subclass(self):
try:
- class C(bool):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(bool):
+ pass
except TypeError:
@ -50,67 +50,67 @@ index 34ecb45f161..12b719c432b 100644
- class Foo(object):
- def __bool__(self):
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo(object):
+ def __bool__(self):
+ return self
check(Foo())
- class Bar(object):
- def __bool__(self):
- return "Yes"
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Bar(object):
+ def __bool__(self):
+ return "Yes"
check(Bar())
- class Baz(int):
- def __bool__(self):
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Baz(int):
+ def __bool__(self):
+ return self
check(Baz())
# __bool__() must return a bool not an int
- class Spam(int):
- def __bool__(self):
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Spam(int):
+ def __bool__(self):
+ return 1
check(Spam())
- class Eggs:
- def __len__(self):
- return -1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Eggs:
+ def __len__(self):
+ return -1
self.assertRaises(ValueError, bool, Eggs())
def test_interpreter_convert_to_bool_raises(self):
- class SymbolicBool:
- def __bool__(self):
- raise TypeError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SymbolicBool:
+ def __bool__(self):
+ raise TypeError
- class Symbol:
- def __gt__(self, other):
- return SymbolicBool()
+ class Symbol:
+ def __gt__(self, other):
+ return SymbolicBool()
x = Symbol()
@@ -361,9 +388,10 @@ class BoolTest(unittest.TestCase):
# this test just tests our assumptions about __len__
# this will start failing if __len__ changes assertions
@ -118,7 +118,7 @@ index 34ecb45f161..12b719c432b 100644
- class A:
- def __len__(self):
- return badval
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def __len__(self):
+ return badval
@ -127,30 +127,30 @@ index 34ecb45f161..12b719c432b 100644
except (Exception) as e_bool:
@@ -373,14 +401,16 @@ class BoolTest(unittest.TestCase):
self.assertEqual(str(e_bool), str(e_len))
def test_blocked(self):
- class A:
- __bool__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ __bool__ = None
self.assertRaises(TypeError, bool, A())
- class B:
- def __len__(self):
- return 10
- __bool__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class B:
+ def __len__(self):
+ return 10
+ __bool__ = None
self.assertRaises(TypeError, bool, B())
def test_real_and_imag(self):
@@ -394,12 +424,13 @@ class BoolTest(unittest.TestCase):
self.assertIs(type(False.imag), int)
def test_bool_called_at_least_once(self):
- class X:
- def __init__(self):
@ -158,19 +158,19 @@ index 34ecb45f161..12b719c432b 100644
- def __bool__(self):
- self.count += 1
- return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __init__(self):
+ self.count = 0
+ def __bool__(self):
+ self.count += 1
+ return True
def f(x):
if x or True:
@@ -418,4 +449,4 @@ class BoolTest(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -29,7 +29,7 @@ class BoolTest(__TestCase):
def test_subclass(self):
try:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(bool):
pass
except TypeError:
@ -328,39 +328,39 @@ class BoolTest(__TestCase):
# from __bool__(). This isn't really a bool test, but
# it's related.
check = lambda o: self.assertRaises(TypeError, bool, o)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo(object):
def __bool__(self):
return self
check(Foo())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Bar(object):
def __bool__(self):
return "Yes"
check(Bar())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Baz(int):
def __bool__(self):
return self
check(Baz())
# __bool__() must return a bool not an int
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Spam(int):
def __bool__(self):
return 1
check(Spam())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Eggs:
def __len__(self):
return -1
self.assertRaises(ValueError, bool, Eggs())
def test_interpreter_convert_to_bool_raises(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SymbolicBool:
def __bool__(self):
raise TypeError
@ -388,7 +388,7 @@ class BoolTest(__TestCase):
# this test just tests our assumptions about __len__
# this will start failing if __len__ changes assertions
for badval in ['illegal', -1, 1 << 32]:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def __len__(self):
return badval
@ -401,12 +401,12 @@ class BoolTest(__TestCase):
self.assertEqual(str(e_bool), str(e_len))
def test_blocked(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
__bool__ = None
self.assertRaises(TypeError, bool, A())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class B:
def __len__(self):
return 10
@ -424,7 +424,7 @@ class BoolTest(__TestCase):
self.assertIs(type(False.imag), int)
def test_bool_called_at_least_once(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __init__(self):
self.count = 0

View File

@ -65,7 +65,7 @@ index a96a5780b31..d00dfca8a17 100644
@@ -50,7 +103,7 @@ complex_nans = [complex(x, y) for x, y in [
(INF, NAN)
]]
-class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
+class CMathTests(__TestCase):
# list of all functions in cmath
@ -74,7 +74,7 @@ index a96a5780b31..d00dfca8a17 100644
@@ -66,6 +119,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
def tearDown(self):
self.test_values.close()
+ def assertFloatIdentical(self, x, y):
+ """Fail unless floats x and y are identical, in the sense that:
+ (1) both x and y are nans, or
@ -113,7 +113,7 @@ index a96a5780b31..d00dfca8a17 100644
"""Fail if the two floating-point numbers are not almost equal.
@@ -165,38 +251,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
# end up being passed to the cmath functions
# usual case: new-style class implementing __complex__
- class MyComplex:
- def __init__(self, value):
@ -127,7 +127,7 @@ index a96a5780b31..d00dfca8a17 100644
- class MyComplexException:
- def __complex__(self):
- raise SomeException
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyComplex:
+ def __init__(self, value):
+ self.value = value
@ -140,7 +140,7 @@ index a96a5780b31..d00dfca8a17 100644
+ class MyComplexException:
+ def __complex__(self):
+ raise SomeException
- # some classes not providing __float__ or __complex__
- class NeitherComplexNorFloat(object):
- pass
@ -179,12 +179,12 @@ index a96a5780b31..d00dfca8a17 100644
+ class JustFloat:
+ def __float__(self):
+ return flt_arg
for f in self.test_functions:
# usual usage
@@ -590,4 +677,4 @@ class IsCloseTests(test_math.IsCloseTests):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -251,7 +251,7 @@ class CMathTests(__TestCase):
# end up being passed to the cmath functions
# usual case: new-style class implementing __complex__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyComplex:
def __init__(self, value):
self.value = value

View File

@ -24,12 +24,12 @@ index cafc44007d1..4571e5a14fd 100644
+# ======= END DYNAMO PATCH =======
+
"""Unit tests for collections.py."""
import array
@@ -29,7 +49,7 @@ from collections.abc import Sequence, MutableSequence
from collections.abc import ByteString, Buffer
-class TestUserObjects(unittest.TestCase):
+class TestUserObjects(__TestCase):
def _superset_test(self, a, b):
@ -37,12 +37,12 @@ index cafc44007d1..4571e5a14fd 100644
set(dir(a)),
@@ -73,9 +93,10 @@ class TestUserObjects(unittest.TestCase):
self._copy_test(obj)
def test_dict_missing(self):
- class A(UserDict):
- def __missing__(self, key):
- return 456
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A(UserDict):
+ def __missing__(self, key):
+ return 456
@ -52,20 +52,20 @@ index cafc44007d1..4571e5a14fd 100644
@@ -85,7 +106,7 @@ class TestUserObjects(unittest.TestCase):
### ChainMap (helper class for configparser and the string module)
################################################################################
-class TestChainMap(unittest.TestCase):
+class TestChainMap(__TestCase):
def test_basics(self):
c = ChainMap()
@@ -172,9 +193,10 @@ class TestChainMap(unittest.TestCase):
self.assertTrue(ChainMap({}, {1:2}))
def test_missing(self):
- class DefaultChainMap(ChainMap):
- def __missing__(self, key):
- return 999
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultChainMap(ChainMap):
+ def __missing__(self, key):
+ return 999
@ -74,7 +74,7 @@ index cafc44007d1..4571e5a14fd 100644
self.assertEqual(d[k], v) # check __getitem__ w/missing
@@ -206,13 +228,14 @@ class TestChainMap(unittest.TestCase):
('i', 9999), ('j', 0)])
def test_iter_not_calling_getitem_on_maps(self):
- class DictWithGetItem(UserDict):
- def __init__(self, *args, **kwds):
@ -83,7 +83,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __getitem__(self, item):
- self.called = True
- UserDict.__getitem__(self, item)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DictWithGetItem(UserDict):
+ def __init__(self, *args, **kwds):
+ self.called = False
@ -91,12 +91,12 @@ index cafc44007d1..4571e5a14fd 100644
+ def __getitem__(self, item):
+ self.called = True
+ UserDict.__getitem__(self, item)
d = DictWithGetItem(a=1)
c = ChainMap(d)
@@ -237,15 +260,16 @@ class TestChainMap(unittest.TestCase):
self.assertIs(m, d.maps[0])
# Use a different map than a dict
- class lowerdict(dict):
- def __getitem__(self, key):
@ -107,7 +107,7 @@ index cafc44007d1..4571e5a14fd 100644
- if isinstance(key, str):
- key = key.lower()
- return dict.__contains__(self, key)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class lowerdict(dict):
+ def __getitem__(self, key):
+ if isinstance(key, str):
@ -117,46 +117,46 @@ index cafc44007d1..4571e5a14fd 100644
+ if isinstance(key, str):
+ key = key.lower()
+ return dict.__contains__(self, key)
c = ChainMap()
c['a'] = 1
@@ -315,7 +339,7 @@ class TestChainMap(unittest.TestCase):
TestNT = namedtuple('TestNT', 'x y z') # type used for pickle tests
-class TestNamedTuple(unittest.TestCase):
+class TestNamedTuple(__TestCase):
def test_factory(self):
Point = namedtuple('Point', 'x y')
@@ -666,8 +690,9 @@ class TestNamedTuple(unittest.TestCase):
NT = namedtuple('NT', ['abc', 'def'], False, True)
def test_namedtuple_subclass_issue_24931(self):
- class Point(namedtuple('_Point', ['x', 'y'])):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Point(namedtuple('_Point', ['x', 'y'])):
+ pass
a = Point(3, 4)
self.assertEqual(a._asdict(), OrderedDict([('x', 3), ('y', 4)]))
@@ -722,21 +747,26 @@ class TestNamedTuple(unittest.TestCase):
### Abstract Base Classes
################################################################################
-class ABCTestCase(unittest.TestCase):
+class ABCTestCase(__TestCase):
def validate_abstract_methods(self, abc, *names):
methodstubs = dict.fromkeys(names, lambda s, *args: 0)
# everything should work will all required methods are present
- C = type('C', (abc,), methodstubs)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ C = type('C', (abc,), methodstubs)
C()
+ # Dynamo raises a hard error here that we can't easily capture
+ # Commenting this part as this would also fail in eager if a user
+ # attempt to run the same code
@ -172,7 +172,7 @@ index cafc44007d1..4571e5a14fd 100644
+ # del stubs[name]
+ # C = type('C', (abc,), stubs)
+ # self.assertRaises(TypeError, C, name)
def validate_isinstance(self, abc, name):
stub = lambda s, *args: 0
@@ -981,19 +1011,21 @@ class TestOneTrickPonyABCs(ABCTestCase):
@ -183,7 +183,7 @@ index cafc44007d1..4571e5a14fd 100644
- class I(Iterable):
- def __iter__(self):
- return super().__iter__()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class I(Iterable):
+ def __iter__(self):
@ -197,7 +197,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __iter__(self): return iter([])
- class ItBlocked(It):
- __iter__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class It:
+ def __iter__(self): return iter([])
@ -216,7 +216,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(list())
- def __reversed__(self):
- return iter(list())
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class R(Reversible):
+ def __iter__(self):
@ -231,7 +231,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __reversed__(self): return reversed([])
- class RevPlusIter(RevNoIter):
- def __iter__(self): return iter([])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check reversible non-iterable (which is not Reversible)
+ class RevNoIter:
+ def __reversed__(self): return reversed([])
@ -249,7 +249,7 @@ index cafc44007d1..4571e5a14fd 100644
- __iter__ = None
- class RevRevBlocked(Rev):
- __reversed__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class Rev:
+ def __iter__(self): return iter([])
@ -274,7 +274,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __contains__(self, item):
- return False
- class DerCol(Col): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class Col(Collection):
+ def __iter__(self):
@ -300,7 +300,7 @@ index cafc44007d1..4571e5a14fd 100644
- class ColNoCont:
- def __iter__(self): return iter([])
- def __len__(self): return 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ColNoIter:
+ def __len__(self): return 0
+ def __contains__(self, item): return False
@ -326,7 +326,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __contains__(self): return True
- __iter__ = None
+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class SizeBlock:
+ def __iter__(self): return iter([])
@ -350,7 +350,7 @@ index cafc44007d1..4571e5a14fd 100644
- return False
- class NonCol(ColImpl):
- __contains__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking in subclass
+ class ColImpl:
+ def __iter__(self):
@ -363,24 +363,24 @@ index cafc44007d1..4571e5a14fd 100644
+ __contains__ = None
self.assertFalse(issubclass(NonCol, Collection))
self.assertFalse(isinstance(NonCol(), Collection))
@@ -1162,30 +1202,32 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
self.validate_abstract_methods(Iterator, '__next__', '__iter__')
- # Issue 10565
- class NextOnly:
- def __next__(self):
- yield 1
- return
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Issue 10565
+ class NextOnly:
+ def __next__(self):
+ yield 1
+ return
self.assertNotIsInstance(NextOnly(), Iterator)
def test_Generator(self):
- class NonGen1:
- def __iter__(self): return self
@ -398,7 +398,7 @@ index cafc44007d1..4571e5a14fd 100644
- def close(self): pass
- def send(self, value): return value
- def throw(self, typ, val=None, tb=None): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NonGen1:
+ def __iter__(self): return self
+ def __next__(self): return None
@ -415,27 +415,27 @@ index cafc44007d1..4571e5a14fd 100644
+ def close(self): pass
+ def send(self, value): return value
+ def throw(self, typ, val=None, tb=None): pass
non_samples = [
None, 42, 3.14, 1j, b"", "", (), [], {}, set(),
@@ -1194,18 +1236,19 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertNotIsInstance(x, Generator)
self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
- class Gen:
- def __iter__(self): return self
- def __next__(self): return None
- def close(self): pass
- def send(self, value): return value
- def throw(self, typ, val=None, tb=None): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Gen:
+ def __iter__(self): return self
+ def __next__(self): return None
+ def close(self): pass
+ def send(self, value): return value
+ def throw(self, typ, val=None, tb=None): pass
- class MinimalGen(Generator):
- def send(self, value):
- return value
@ -446,50 +446,50 @@ index cafc44007d1..4571e5a14fd 100644
+ return value
+ def throw(self, typ, val=None, tb=None):
+ super().throw(typ, val, tb)
def gen():
yield 1
@@ -1228,15 +1271,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
mgen.throw, ValueError, ValueError("huhu"))
self.assertRaises(StopIteration, mgen.throw, StopIteration())
- class FailOnClose(Generator):
- def send(self, value): return value
- def throw(self, *args): raise ValueError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailOnClose(Generator):
+ def send(self, value): return value
+ def throw(self, *args): raise ValueError
self.assertRaises(ValueError, FailOnClose().close)
- class IgnoreGeneratorExit(Generator):
- def send(self, value): return value
- def throw(self, *args): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IgnoreGeneratorExit(Generator):
+ def send(self, value): return value
+ def throw(self, *args): pass
self.assertRaises(RuntimeError, IgnoreGeneratorExit().close)
@@ -1379,15 +1424,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_direct_subclassing(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
- class C(B):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(B):
+ pass
self.assertTrue(issubclass(C, B))
self.assertFalse(issubclass(int, C))
def test_registration(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
- class C:
- __hash__ = None # Make sure it isn't hashable by default
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ __hash__ = None # Make sure it isn't hashable by default
self.assertFalse(issubclass(C, B), B.__name__)
@ -506,7 +506,7 @@ index cafc44007d1..4571e5a14fd 100644
- return 0
- def __iter__(self):
- return iter([])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __contains__(self, x):
+ return False
@ -515,11 +515,11 @@ index cafc44007d1..4571e5a14fd 100644
+ def __iter__(self):
+ return iter([])
self.validate_comparison(MySet())
def test_hash_Set(self):
@@ -1448,15 +1496,16 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(hash(a) == hash(b))
def test_isdisjoint_Set(self):
- class MySet(Set):
- def __init__(self, itr):
@ -530,7 +530,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
@ -545,7 +545,7 @@ index cafc44007d1..4571e5a14fd 100644
s3 = MySet((1, 5, 6))
@@ -1464,15 +1513,16 @@ class TestCollectionABCs(ABCTestCase):
self.assertFalse(s1.isdisjoint(s3))
def test_equality_Set(self):
- class MySet(Set):
- def __init__(self, itr):
@ -556,7 +556,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
@ -571,7 +571,7 @@ index cafc44007d1..4571e5a14fd 100644
s3 = MySet((3, 4))
@@ -1486,15 +1536,16 @@ class TestCollectionABCs(ABCTestCase):
self.assertNotEqual(s2, s3)
def test_arithmetic_Set(self):
- class MySet(Set):
- def __init__(self, itr):
@ -582,7 +582,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
@ -596,7 +596,7 @@ index cafc44007d1..4571e5a14fd 100644
s2 = MySet((3, 4, 5))
s3 = s1 & s2
@@ -1516,28 +1567,29 @@ class TestCollectionABCs(ABCTestCase):
def test_issue_4920(self):
# MutableSet.pop() method did not work
- class MySet(MutableSet):
@ -621,7 +621,7 @@ index cafc44007d1..4571e5a14fd 100644
- return result
- def __repr__(self):
- return "MySet(%s)" % repr(list(self))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(MutableSet):
+ __slots__=['__s']
+ def __init__(self,items=None):
@ -669,7 +669,7 @@ index cafc44007d1..4571e5a14fd 100644
- return NotImplemented
- def __lt__(self, x):
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyComparableSet(Set):
+ def __contains__(self, x):
+ return False
@ -688,11 +688,11 @@ index cafc44007d1..4571e5a14fd 100644
+ return NotImplemented
+ def __lt__(self, x):
+ return NotImplemented
cs = MyComparableSet()
ncs = MyNonComparableSet()
@@ -1591,13 +1644,14 @@ class TestCollectionABCs(ABCTestCase):
def test_issue26915(self):
# Container membership test should check identity first
- class CustomSequence(Sequence):
@ -702,7 +702,7 @@ index cafc44007d1..4571e5a14fd 100644
- return self._seq[index]
- def __len__(self):
- return len(self._seq)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomSequence(Sequence):
+ def __init__(self, seq):
+ self._seq = seq
@ -710,11 +710,11 @@ index cafc44007d1..4571e5a14fd 100644
+ return self._seq[index]
+ def __len__(self):
+ return len(self._seq)
nan = float('nan')
obj = support.NEVER_EQ
@@ -1622,30 +1676,31 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_from_iterable(self):
"""Verify _from_iterable overridden to an instance method works."""
- class SetUsingInstanceFromIterable(MutableSet):
@ -723,48 +723,48 @@ index cafc44007d1..4571e5a14fd 100644
- raise ValueError('created_by must be specified')
- self.created_by = created_by
- self._values = set(values)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SetUsingInstanceFromIterable(MutableSet):
+ def __init__(self, values, created_by):
+ if not created_by:
+ raise ValueError('created_by must be specified')
+ self.created_by = created_by
+ self._values = set(values)
- def _from_iterable(self, values):
- return type(self)(values, 'from_iterable')
+ def _from_iterable(self, values):
+ return type(self)(values, 'from_iterable')
- def __contains__(self, value):
- return value in self._values
+ def __contains__(self, value):
+ return value in self._values
- def __iter__(self):
- yield from self._values
+ def __iter__(self):
+ yield from self._values
- def __len__(self):
- return len(self._values)
+ def __len__(self):
+ return len(self._values)
- def add(self, value):
- self._values.add(value)
+ def add(self, value):
+ self._values.add(value)
- def discard(self, value):
- self._values.discard(value)
+ def discard(self, value):
+ self._values.discard(value)
impl = SetUsingInstanceFromIterable([1, 2, 3], 'test')
@@ -1678,20 +1733,21 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_interoperability_with_real_sets(self):
# Issue: 8743
- class ListSet(Set):
@ -781,7 +781,7 @@ index cafc44007d1..4571e5a14fd 100644
- return len(self.data)
- def __repr__(self):
- return 'Set({!r})'.format(self.data)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ListSet(Set):
+ def __init__(self, elements=()):
+ self.data = []
@ -796,7 +796,7 @@ index cafc44007d1..4571e5a14fd 100644
+ return len(self.data)
+ def __repr__(self):
+ return 'Set({!r})'.format(self.data)
r1 = set('abc')
r2 = set('bcd')
@@ -1846,13 +1902,14 @@ class TestCollectionABCs(ABCTestCase):
@ -810,7 +810,7 @@ index cafc44007d1..4571e5a14fd 100644
- raise IndexError
- def __iter__(self):
- return iter(())
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyMapping(Mapping):
+ def __len__(self):
+ return 0
@ -820,7 +820,7 @@ index cafc44007d1..4571e5a14fd 100644
+ return iter(())
self.validate_comparison(MyMapping())
self.assertRaises(TypeError, reversed, MyMapping())
@@ -1860,7 +1917,7 @@ class TestCollectionABCs(ABCTestCase):
for sample in [dict]:
self.assertIsInstance(sample(), MutableMapping)
@ -828,30 +828,30 @@ index cafc44007d1..4571e5a14fd 100644
- self.validate_abstract_methods(MutableMapping, '__contains__', '__iter__', '__len__',
+ self.validate_abstract_methods(MutableMapping, '__iter__', '__len__',
'__getitem__', '__setitem__', '__delitem__')
def test_MutableMapping_subclass(self):
@@ -1903,15 +1960,16 @@ class TestCollectionABCs(ABCTestCase):
'__getitem__')
def test_Sequence_mixins(self):
- class SequenceSubclass(Sequence):
- def __init__(self, seq=()):
- self.seq = seq
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SequenceSubclass(Sequence):
+ def __init__(self, seq=()):
+ self.seq = seq
- def __getitem__(self, index):
- return self.seq[index]
+ def __getitem__(self, index):
+ return self.seq[index]
- def __len__(self):
- return len(self.seq)
+ def __len__(self):
+ return len(self.seq)
# Compare Sequence.index() behavior to (list|str).index() behavior
def assert_index_same(seq1, seq2, index_args):
@@ -1983,24 +2041,25 @@ class TestCollectionABCs(ABCTestCase):
@ -861,54 +861,54 @@ index cafc44007d1..4571e5a14fd 100644
- class MutableSequenceSubclass(MutableSequence):
- def __init__(self):
- self.lst = []
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MutableSequenceSubclass(MutableSequence):
+ def __init__(self):
+ self.lst = []
- def __setitem__(self, index, value):
- self.lst[index] = value
+ def __setitem__(self, index, value):
+ self.lst[index] = value
- def __getitem__(self, index):
- return self.lst[index]
+ def __getitem__(self, index):
+ return self.lst[index]
- def __len__(self):
- return len(self.lst)
+ def __len__(self):
+ return len(self.lst)
- def __delitem__(self, index):
- del self.lst[index]
+ def __delitem__(self, index):
+ del self.lst[index]
- def insert(self, index, value):
- self.lst.insert(index, value)
+ def insert(self, index, value):
+ self.lst.insert(index, value)
mss = MutableSequenceSubclass()
mss.append(0)
@@ -2059,7 +2118,7 @@ class CounterSubclassWithGet(Counter):
self.called = True
return Counter.get(self, key, default)
-class TestCounter(unittest.TestCase):
+class TestCounter(__TestCase):
def test_basics(self):
c = Counter('abcaba')
@@ -2225,8 +2284,9 @@ class TestCounter(unittest.TestCase):
check(Counter(words))
def test_copy_subclass(self):
- class MyCounter(Counter):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyCounter(Counter):
+ pass
c = MyCounter('slartibartfast')
@ -916,8 +916,8 @@ index cafc44007d1..4571e5a14fd 100644
self.assertEqual(d, c)
@@ -2402,10 +2462,5 @@ class TestCounter(unittest.TestCase):
self.assertFalse(Counter(a=2, b=1, c=0) > Counter('aab'))
-def load_tests(loader, tests, pattern):
- tests.addTest(doctest.DocTestSuite(collections))
- return tests

View File

@ -93,7 +93,7 @@ class TestUserObjects(__TestCase):
self._copy_test(obj)
def test_dict_missing(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A(UserDict):
def __missing__(self, key):
return 456
@ -193,7 +193,7 @@ class TestChainMap(__TestCase):
self.assertTrue(ChainMap({}, {1:2}))
def test_missing(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultChainMap(ChainMap):
def __missing__(self, key):
return 999
@ -228,7 +228,7 @@ class TestChainMap(__TestCase):
('i', 9999), ('j', 0)])
def test_iter_not_calling_getitem_on_maps(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DictWithGetItem(UserDict):
def __init__(self, *args, **kwds):
self.called = False
@ -260,7 +260,7 @@ class TestChainMap(__TestCase):
self.assertIs(m, d.maps[0])
# Use a different map than a dict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class lowerdict(dict):
def __getitem__(self, key):
if isinstance(key, str):
@ -690,7 +690,7 @@ class TestNamedTuple(__TestCase):
NT = namedtuple('NT', ['abc', 'def'], False, True)
def test_namedtuple_subclass_issue_24931(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Point(namedtuple('_Point', ['x', 'y'])):
pass
@ -753,7 +753,7 @@ class ABCTestCase(__TestCase):
methodstubs = dict.fromkeys(names, lambda s, *args: 0)
# everything should work will all required methods are present
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
C = type('C', (abc,), methodstubs)
C()
@ -1011,7 +1011,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
for x in samples:
self.assertIsInstance(x, Iterable)
self.assertTrue(issubclass(type(x), Iterable), repr(type(x)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check direct subclassing
class I(Iterable):
def __iter__(self):
@ -1020,7 +1020,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(issubclass(str, I))
self.validate_abstract_methods(Iterable, '__iter__')
self.validate_isinstance(Iterable, '__iter__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking
class It:
def __iter__(self): return iter([])
@ -1055,7 +1055,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(Sequence, Reversible), repr(Sequence))
self.assertFalse(issubclass(Mapping, Reversible), repr(Mapping))
self.assertFalse(issubclass(MutableMapping, Reversible), repr(MutableMapping))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check direct subclassing
class R(Reversible):
def __iter__(self):
@ -1065,7 +1065,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertEqual(list(reversed(R())), [])
self.assertFalse(issubclass(float, R))
self.validate_abstract_methods(Reversible, '__reversed__', '__iter__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check reversible non-iterable (which is not Reversible)
class RevNoIter:
def __reversed__(self): return reversed([])
@ -1075,7 +1075,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(isinstance(RevNoIter(), Reversible))
self.assertTrue(issubclass(RevPlusIter, Reversible))
self.assertTrue(isinstance(RevPlusIter(), Reversible))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking
class Rev:
def __iter__(self): return iter([])
@ -1117,7 +1117,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(Set, Collection), repr(Set))
self.assertTrue(issubclass(MutableSet, Collection), repr(MutableSet))
self.assertTrue(issubclass(Sequence, Collection), repr(MutableSet))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check direct subclassing
class Col(Collection):
def __iter__(self):
@ -1138,7 +1138,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.validate_abstract_methods(Collection, '__len__', '__iter__',
'__contains__')
# Check sized container non-iterable (which is not Collection) etc.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ColNoIter:
def __len__(self): return 0
def __contains__(self, item): return False
@ -1155,7 +1155,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(issubclass(ColNoCont, Collection))
self.assertFalse(isinstance(ColNoCont(), Collection))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking
class SizeBlock:
def __iter__(self): return iter([])
@ -1169,7 +1169,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(isinstance(SizeBlock(), Collection))
self.assertFalse(issubclass(IterBlock, Collection))
self.assertFalse(isinstance(IterBlock(), Collection))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking in subclass
class ColImpl:
def __iter__(self):
@ -1202,7 +1202,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
self.validate_abstract_methods(Iterator, '__next__', '__iter__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Issue 10565
class NextOnly:
def __next__(self):
@ -1211,7 +1211,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertNotIsInstance(NextOnly(), Iterator)
def test_Generator(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NonGen1:
def __iter__(self): return self
def __next__(self): return None
@ -1236,7 +1236,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertNotIsInstance(x, Generator)
self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Gen:
def __iter__(self): return self
def __next__(self): return None
@ -1271,14 +1271,14 @@ class TestOneTrickPonyABCs(ABCTestCase):
mgen.throw, ValueError, ValueError("huhu"))
self.assertRaises(StopIteration, mgen.throw, StopIteration())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailOnClose(Generator):
def send(self, value): return value
def throw(self, *args): raise ValueError
self.assertRaises(ValueError, FailOnClose().close)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IgnoreGeneratorExit(Generator):
def send(self, value): return value
def throw(self, *args): pass
@ -1424,7 +1424,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_direct_subclassing(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(B):
pass
self.assertTrue(issubclass(C, B))
@ -1432,7 +1432,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_registration(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
__hash__ = None # Make sure it isn't hashable by default
self.assertFalse(issubclass(C, B), B.__name__)
@ -1470,7 +1470,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertIsInstance(sample(), Set)
self.assertTrue(issubclass(sample, Set))
self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __contains__(self, x):
return False
@ -1496,7 +1496,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(hash(a) == hash(b))
def test_isdisjoint_Set(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __init__(self, itr):
self.contents = itr
@ -1513,7 +1513,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertFalse(s1.isdisjoint(s3))
def test_equality_Set(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __init__(self, itr):
self.contents = itr
@ -1536,7 +1536,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertNotEqual(s2, s3)
def test_arithmetic_Set(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __init__(self, itr):
self.contents = itr
@ -1567,7 +1567,7 @@ class TestCollectionABCs(ABCTestCase):
def test_issue_4920(self):
# MutableSet.pop() method did not work
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(MutableSet):
__slots__=['__s']
def __init__(self,items=None):
@ -1615,7 +1615,7 @@ class TestCollectionABCs(ABCTestCase):
def test_issue16373(self):
# Recursion error comparing comparable and noncomparable
# Set instances
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyComparableSet(Set):
def __contains__(self, x):
return False
@ -1644,7 +1644,7 @@ class TestCollectionABCs(ABCTestCase):
def test_issue26915(self):
# Container membership test should check identity first
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomSequence(Sequence):
def __init__(self, seq):
self._seq = seq
@ -1676,7 +1676,7 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_from_iterable(self):
"""Verify _from_iterable overridden to an instance method works."""
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SetUsingInstanceFromIterable(MutableSet):
def __init__(self, values, created_by):
if not created_by:
@ -1733,7 +1733,7 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_interoperability_with_real_sets(self):
# Issue: 8743
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ListSet(Set):
def __init__(self, elements=()):
self.data = []
@ -1902,7 +1902,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(issubclass(sample, Mapping))
self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__',
'__getitem__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyMapping(Mapping):
def __len__(self):
return 0
@ -1960,7 +1960,7 @@ class TestCollectionABCs(ABCTestCase):
'__getitem__')
def test_Sequence_mixins(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SequenceSubclass(Sequence):
def __init__(self, seq=()):
self.seq = seq
@ -2041,7 +2041,7 @@ class TestCollectionABCs(ABCTestCase):
def test_MutableSequence_mixins(self):
# Test the mixins of MutableSequence by creating a minimal concrete
# class inherited from it.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MutableSequenceSubclass(MutableSequence):
def __init__(self):
self.lst = []
@ -2284,7 +2284,7 @@ class TestCounter(__TestCase):
check(Counter(words))
def test_copy_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyCounter(Counter):
pass
c = MyCounter('slartibartfast')

View File

@ -43,7 +43,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
+ "test.test_iter",
+ "test.typinganndata.ann_module",
)
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
@ -74,7 +74,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
from math import isnan, copysign
+import math
import operator
+VALID_UNDERSCORE_LITERALS = [
+ '0_0_0',
+ '4_2',
@ -158,7 +158,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
@@ -45,7 +176,40 @@ class WithComplex:
def __complex__(self):
return self.value
-class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
+class ComplexTest(__TestCase):
+
@ -194,13 +194,13 @@ index 6ff1a8ab29d..1572433c5ae 100644
+ """
+ self.assertFloatIdentical(x.real, y.real)
+ self.assertFloatIdentical(x.imag, y.imag)
def assertAlmostEqual(self, a, b):
if isinstance(a, complex):
@@ -74,6 +238,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
# check that relative difference < eps
self.assertTrue(abs((x-y)/y) < eps)
+ def assertFloatsAreIdentical(self, x, y):
+ """assert that floats x and y are identical, in the sense that:
+ (1) both x and y are nans, or
@ -230,58 +230,58 @@ index 6ff1a8ab29d..1572433c5ae 100644
@@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
q = z.__truediv__(y)
self.assertClose(q, x)
+ @slowTest
def test_truediv(self):
simple_real = [float(i) for i in range(-5, 6)]
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
@@ -338,7 +526,10 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
def test_boolcontext(self):
for i in range(100):
- self.assertTrue(complex(random() + 1e-6, random() + 1e-6))
+ with torch._dynamo.set_fullgraph(False):
+ with torch._dynamo.error_on_graph_break(False):
+ r1 = random()
+ r2 = random()
+ self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
self.assertTrue(not complex(0.0, 0.0))
self.assertTrue(1j)
@@ -431,12 +622,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(TypeError, complex, WithComplex(1), object())
self.assertRaises(TypeError, complex, WithComplex(None), object())
- class EvilExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class EvilExc(Exception):
+ pass
- class evilcomplex:
- def __complex__(self):
- raise EvilExc
+ class evilcomplex:
+ def __complex__(self):
+ raise EvilExc
self.assertRaises(EvilExc, complex, evilcomplex())
@@ -460,31 +652,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
- class MyInt:
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyInt:
+ def __int__(self):
+ return 42
self.assertRaises(TypeError, complex, MyInt())
self.assertRaises(TypeError, complex, MyInt(), 1.5)
self.assertRaises(TypeError, complex, 1.5, MyInt())
- class complex0(complex):
- """Test usage of __complex__() when inheriting from 'complex'"""
- def __complex__(self):
@ -299,7 +299,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
- complex is returned"""
- def __complex__(self):
- return None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class complex0(complex):
+ """Test usage of __complex__() when inheriting from 'complex'"""
+ def __complex__(self):
@ -317,12 +317,12 @@ index 6ff1a8ab29d..1572433c5ae 100644
+ complex is returned"""
+ def __complex__(self):
+ return None
check(complex(complex0(1j)), 0.0, 42.0)
with self.assertWarns(DeprecationWarning):
@@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -526,7 +526,7 @@ class ComplexTest(__TestCase):
def test_boolcontext(self):
for i in range(100):
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
r1 = random()
r2 = random()
self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
@ -622,7 +622,7 @@ class ComplexTest(__TestCase):
self.assertRaises(TypeError, complex, WithComplex(1), object())
self.assertRaises(TypeError, complex, WithComplex(None), object())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EvilExc(Exception):
pass
@ -652,7 +652,7 @@ class ComplexTest(__TestCase):
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyInt:
def __int__(self):
return 42
@ -661,7 +661,7 @@ class ComplexTest(__TestCase):
self.assertRaises(TypeError, complex, MyInt(), 1.5)
self.assertRaises(TypeError, complex, 1.5, MyInt())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class complex0(complex):
"""Test usage of __complex__() when inheriting from 'complex'"""
def __complex__(self):

View File

@ -58,121 +58,121 @@ index cf651959803..256a824932d 100644
+# ======= END DYNAMO PATCH =======
+
"""Unit tests for contextlib.py, and other context managers."""
import io
@@ -14,60 +68,67 @@ from test.support.testcase import ExceptionIsLikeMixin
import weakref
-class TestAbstractContextManager(unittest.TestCase):
+class TestAbstractContextManager(__TestCase):
def test_enter(self):
- class DefaultEnter(AbstractContextManager):
- def __exit__(self, *args):
- super().__exit__(*args)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultEnter(AbstractContextManager):
+ def __exit__(self, *args):
+ super().__exit__(*args)
manager = DefaultEnter()
self.assertIs(manager.__enter__(), manager)
def test_slots(self):
- class DefaultContextManager(AbstractContextManager):
- __slots__ = ()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultContextManager(AbstractContextManager):
+ __slots__ = ()
- def __exit__(self, *args):
- super().__exit__(*args)
+ def __exit__(self, *args):
+ super().__exit__(*args)
with self.assertRaises(AttributeError):
DefaultContextManager().var = 42
def test_exit_is_abstract(self):
- class MissingExit(AbstractContextManager):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MissingExit(AbstractContextManager):
+ pass
with self.assertRaises(TypeError):
MissingExit()
def test_structural_subclassing(self):
- class ManagerFromScratch:
- def __enter__(self):
- return self
- def __exit__(self, exc_type, exc_value, traceback):
- return None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ManagerFromScratch:
+ def __enter__(self):
+ return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ return None
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
- class DefaultEnter(AbstractContextManager):
- def __exit__(self, *args):
- super().__exit__(*args)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultEnter(AbstractContextManager):
+ def __exit__(self, *args):
+ super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
- class NoEnter(ManagerFromScratch):
- __enter__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NoEnter(ManagerFromScratch):
+ __enter__ = None
self.assertFalse(issubclass(NoEnter, AbstractContextManager))
- class NoExit(ManagerFromScratch):
- __exit__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NoExit(ManagerFromScratch):
+ __exit__ = None
self.assertFalse(issubclass(NoExit, AbstractContextManager))
-class ContextManagerTestCase(unittest.TestCase):
+class ContextManagerTestCase(__TestCase):
def test_contextmanager_plain(self):
state = []
@@ -115,8 +176,9 @@ class ContextManagerTestCase(unittest.TestCase):
self.assertEqual(frames[0].line, '1/0')
# Repeat with RuntimeError (which goes through a different code path)
- class RuntimeErrorSubclass(RuntimeError):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class RuntimeErrorSubclass(RuntimeError):
+ pass
try:
with f():
@@ -128,8 +190,9 @@ class ContextManagerTestCase(unittest.TestCase):
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
- class StopIterationSubclass(StopIteration):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class StopIterationSubclass(StopIteration):
+ pass
for stop_exc in (
StopIteration('spam'),
@@ -169,9 +232,9 @@ class ContextManagerTestCase(unittest.TestCase):
@ -185,7 +185,7 @@ index cf651959803..256a824932d 100644
+ # if support.check_impl_detail(cpython=True):
+ # # The "gen" attribute is an implementation detail.
+ # self.assertFalse(ctx.gen.gi_suspended)
def test_contextmanager_trap_no_yield(self):
@contextmanager
@@ -191,9 +254,9 @@ class ContextManagerTestCase(unittest.TestCase):
@ -198,50 +198,50 @@ index cf651959803..256a824932d 100644
+ # if support.check_impl_detail(cpython=True):
+ # # The "gen" attribute is an implementation detail.
+ # self.assertFalse(ctx.gen.gi_suspended)
def test_contextmanager_non_normalised(self):
@contextmanager
@@ -230,8 +293,9 @@ class ContextManagerTestCase(unittest.TestCase):
def woohoo():
yield
- class StopIterationSubclass(StopIteration):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class StopIterationSubclass(StopIteration):
+ pass
for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
with self.subTest(type=type(stop_exc)):
@@ -344,8 +408,9 @@ def woohoo():
self.assertEqual(target, (11, 22, 33, 44))
def test_nokeepref(self):
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
@contextmanager
def woohoo(a, b):
@@ -396,7 +461,7 @@ def woohoo():
self.assertEqual(depth, 0)
-class ClosingTestCase(unittest.TestCase):
+class ClosingTestCase(__TestCase):
@support.requires_docstrings
def test_instance_docs(self):
@@ -407,9 +472,10 @@ class ClosingTestCase(unittest.TestCase):
def test_closing(self):
state = []
- class C:
- def close(self):
- state.append(1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def close(self):
+ state.append(1)
@ -249,13 +249,13 @@ index cf651959803..256a824932d 100644
self.assertEqual(state, [])
with closing(x) as y:
@@ -418,9 +484,10 @@ class ClosingTestCase(unittest.TestCase):
def test_closing_error(self):
state = []
- class C:
- def close(self):
- state.append(1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def close(self):
+ state.append(1)
@ -264,52 +264,52 @@ index cf651959803..256a824932d 100644
with self.assertRaises(ZeroDivisionError):
@@ -430,16 +497,17 @@ class ClosingTestCase(unittest.TestCase):
self.assertEqual(state, [1])
-class NullcontextTestCase(unittest.TestCase):
+class NullcontextTestCase(__TestCase):
def test_nullcontext(self):
- class C:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ pass
c = C()
with nullcontext(c) as c_in:
self.assertIs(c_in, c)
-class FileContextTestCase(unittest.TestCase):
+class FileContextTestCase(__TestCase):
def testWithOpen(self):
tfn = tempfile.mktemp()
@@ -457,7 +525,7 @@ class FileContextTestCase(unittest.TestCase):
finally:
os_helper.unlink(tfn)
-class LockContextTestCase(unittest.TestCase):
+class LockContextTestCase(__TestCase):
def boilerPlate(self, lock, locked):
self.assertFalse(locked())
@@ -520,7 +588,7 @@ class mycontext(ContextDecorator):
return self.catch
-class TestContextDecorator(unittest.TestCase):
+class TestContextDecorator(__TestCase):
@support.requires_docstrings
def test_instance_docs(self):
@@ -584,13 +652,14 @@ class TestContextDecorator(unittest.TestCase):
def test_decorating_method(self):
context = mycontext()
- class Test(object):
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Test(object):
- @context
- def method(self, a, b, c=None):
- self.a = a
@ -320,84 +320,84 @@ index cf651959803..256a824932d 100644
+ self.a = a
+ self.b = b
+ self.c = c
# these tests are for argument passing when used as a decorator
test = Test()
@@ -612,11 +681,12 @@ class TestContextDecorator(unittest.TestCase):
def test_typo_enter(self):
- class mycontext(ContextDecorator):
- def __unter__(self):
- pass
- def __exit__(self, *exc):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mycontext(ContextDecorator):
+ def __unter__(self):
+ pass
+ def __exit__(self, *exc):
+ pass
with self.assertRaisesRegex(TypeError, 'the context manager'):
with mycontext():
@@ -624,11 +694,12 @@ class TestContextDecorator(unittest.TestCase):
def test_typo_exit(self):
- class mycontext(ContextDecorator):
- def __enter__(self):
- pass
- def __uxit__(self, *exc):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mycontext(ContextDecorator):
+ def __enter__(self):
+ pass
+ def __uxit__(self, *exc):
+ pass
with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
with mycontext():
@@ -636,19 +707,20 @@ class TestContextDecorator(unittest.TestCase):
def test_contextdecorator_as_mixin(self):
- class somecontext(object):
- started = False
- exc = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class somecontext(object):
+ started = False
+ exc = None
- def __enter__(self):
- self.started = True
- return self
+ def __enter__(self):
+ self.started = True
+ return self
- def __exit__(self, *exc):
- self.exc = exc
+ def __exit__(self, *exc):
+ self.exc = exc
- class mycontext(somecontext, ContextDecorator):
- pass
+ class mycontext(somecontext, ContextDecorator):
+ pass
context = mycontext()
@context
@@ -680,7 +752,7 @@ class TestContextDecorator(unittest.TestCase):
self.assertEqual(state, [1, 'something else', 999])
-class TestBaseExitStack:
+class _TestBaseExitStack:
exit_stack = None
@support.requires_docstrings
@@ -745,13 +817,14 @@ class TestBaseExitStack:
self.assertIsNone(exc_type)
@ -410,7 +410,7 @@ index cf651959803..256a824932d 100644
- self.fail("Should not be called!")
- def __exit__(self, *exc_details):
- self.check_exc(*exc_details)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ExitCM(object):
+ def __init__(self, check_exc):
+ self.check_exc = check_exc
@ -423,25 +423,25 @@ index cf651959803..256a824932d 100644
self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
@@ -770,11 +843,12 @@ class TestBaseExitStack:
1/0
def test_enter_context(self):
- class TestCM(object):
- def __enter__(self):
- result.append(1)
- def __exit__(self, *exc_details):
- result.append(3)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestCM(object):
+ def __enter__(self):
+ result.append(1)
+ def __exit__(self, *exc_details):
+ result.append(3)
result = []
cm = TestCM()
@@ -789,14 +863,15 @@ class TestBaseExitStack:
self.assertEqual(result, [1, 2, 3, 4])
def test_enter_context_errors(self):
- class LacksEnterAndExit:
- pass
@ -450,7 +450,7 @@ index cf651959803..256a824932d 100644
- pass
- class LacksExit:
- def __enter__(self):
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnterAndExit:
pass
+ class LacksEnter:
@ -459,7 +459,7 @@ index cf651959803..256a824932d 100644
+ class LacksExit:
+ def __enter__(self):
+ pass
with self.exit_stack() as stack:
with self.assertRaisesRegex(TypeError, 'the context manager'):
@@ -877,32 +952,33 @@ class TestBaseExitStack:
@ -492,7 +492,7 @@ index cf651959803..256a824932d 100644
- def __exit__(self, *exc_details):
- type(self).saved_details = exc_details
- return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class RaiseExc:
+ def __init__(self, exc):
+ self.exc = exc
@ -519,47 +519,47 @@ index cf651959803..256a824932d 100644
+ def __exit__(self, *exc_details):
+ type(self).saved_details = exc_details
+ return True
try:
with RaiseExc(IndexError):
@@ -957,8 +1033,9 @@ class TestBaseExitStack:
# Ensure ExitStack chaining matches actual nested `with` statements
# regarding explicit __context__ = None.
- class MyException(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ pass
@contextmanager
def my_cm():
@@ -1096,7 +1173,8 @@ class TestBaseExitStack:
stack.callback(int)
def test_instance_bypass(self):
- class Example(object): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Example(object): pass
cm = Example()
cm.__enter__ = object()
cm.__exit__ = object()
@@ -1108,8 +1186,9 @@ class TestBaseExitStack:
def test_dont_reraise_RuntimeError(self):
# https://bugs.python.org/issue27122
- class UniqueException(Exception): pass
- class UniqueRuntimeError(RuntimeError): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class UniqueException(Exception): pass
+ class UniqueRuntimeError(RuntimeError): pass
@contextmanager
def second():
@@ -1141,7 +1220,7 @@ class TestBaseExitStack:
self.assertIs(exc.__cause__, exc.__context__)
-class TestExitStack(TestBaseExitStack, unittest.TestCase):
+class TestExitStack(_TestBaseExitStack, __TestCase):
exit_stack = ExitStack
@ -567,40 +567,40 @@ index cf651959803..256a824932d 100644
('__exit__', 'raise exc'),
@@ -1149,7 +1228,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase):
]
-class TestRedirectStream:
+class _TestRedirectStream:
redirect_stream = None
orig_stream = None
@@ -1206,19 +1285,19 @@ class TestRedirectStream:
self.assertEqual(s, "Hello World!\n")
-class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
+class TestRedirectStdout(_TestRedirectStream, __TestCase):
redirect_stream = redirect_stdout
orig_stream = "stdout"
-class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
+class TestRedirectStderr(_TestRedirectStream, __TestCase):
redirect_stream = redirect_stderr
orig_stream = "stderr"
-class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
+class TestSuppress(ExceptionIsLikeMixin, __TestCase):
@support.requires_docstrings
def test_instance_docs(self):
@@ -1315,7 +1394,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
)
-class TestChdir(unittest.TestCase):
+class TestChdir(__TestCase):
def make_relative_path(self, *parts):
@ -609,14 +609,14 @@ index cf651959803..256a824932d 100644
@@ -1331,6 +1410,7 @@ class TestChdir(unittest.TestCase):
self.assertEqual(os.getcwd(), target)
self.assertEqual(os.getcwd(), old_cwd)
+ @unittest.skip("Missing archivetestdata")
def test_reentrant(self):
old_cwd = os.getcwd()
target1 = self.make_relative_path('data')
@@ -1363,4 +1443,4 @@ class TestChdir(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -71,7 +71,7 @@ import weakref
class TestAbstractContextManager(__TestCase):
def test_enter(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
@ -80,7 +80,7 @@ class TestAbstractContextManager(__TestCase):
self.assertIs(manager.__enter__(), manager)
def test_slots(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultContextManager(AbstractContextManager):
__slots__ = ()
@ -91,7 +91,7 @@ class TestAbstractContextManager(__TestCase):
DefaultContextManager().var = 42
def test_exit_is_abstract(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MissingExit(AbstractContextManager):
pass
@ -99,7 +99,7 @@ class TestAbstractContextManager(__TestCase):
MissingExit()
def test_structural_subclassing(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ManagerFromScratch:
def __enter__(self):
return self
@ -108,20 +108,20 @@ class TestAbstractContextManager(__TestCase):
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NoEnter(ManagerFromScratch):
__enter__ = None
self.assertFalse(issubclass(NoEnter, AbstractContextManager))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NoExit(ManagerFromScratch):
__exit__ = None
@ -176,7 +176,7 @@ class ContextManagerTestCase(__TestCase):
self.assertEqual(frames[0].line, '1/0')
# Repeat with RuntimeError (which goes through a different code path)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class RuntimeErrorSubclass(RuntimeError):
pass
@ -190,7 +190,7 @@ class ContextManagerTestCase(__TestCase):
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class StopIterationSubclass(StopIteration):
pass
@ -293,7 +293,7 @@ class ContextManagerTestCase(__TestCase):
def woohoo():
yield
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class StopIterationSubclass(StopIteration):
pass
@ -408,7 +408,7 @@ def woohoo():
self.assertEqual(target, (11, 22, 33, 44))
def test_nokeepref(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
@ -472,7 +472,7 @@ class ClosingTestCase(__TestCase):
def test_closing(self):
state = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def close(self):
state.append(1)
@ -484,7 +484,7 @@ class ClosingTestCase(__TestCase):
def test_closing_error(self):
state = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def close(self):
state.append(1)
@ -499,7 +499,7 @@ class ClosingTestCase(__TestCase):
class NullcontextTestCase(__TestCase):
def test_nullcontext(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
pass
c = C()
@ -652,7 +652,7 @@ class TestContextDecorator(__TestCase):
def test_decorating_method(self):
context = mycontext()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Test(object):
@context
@ -681,7 +681,7 @@ class TestContextDecorator(__TestCase):
def test_typo_enter(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mycontext(ContextDecorator):
def __unter__(self):
pass
@ -694,7 +694,7 @@ class TestContextDecorator(__TestCase):
def test_typo_exit(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mycontext(ContextDecorator):
def __enter__(self):
pass
@ -707,7 +707,7 @@ class TestContextDecorator(__TestCase):
def test_contextdecorator_as_mixin(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class somecontext(object):
started = False
exc = None
@ -817,7 +817,7 @@ class _TestBaseExitStack:
self.assertIsNone(exc_type)
self.assertIsNone(exc)
self.assertIsNone(exc_tb)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ExitCM(object):
def __init__(self, check_exc):
self.check_exc = check_exc
@ -843,7 +843,7 @@ class _TestBaseExitStack:
1/0
def test_enter_context(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestCM(object):
def __enter__(self):
result.append(1)
@ -863,7 +863,7 @@ class _TestBaseExitStack:
self.assertEqual(result, [1, 2, 3, 4])
def test_enter_context_errors(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksEnterAndExit:
pass
class LacksEnter:
@ -952,7 +952,7 @@ class _TestBaseExitStack:
def test_exit_exception_chaining_reference(self):
# Sanity check to make sure that ExitStack chaining matches
# actual nested with statements
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class RaiseExc:
def __init__(self, exc):
self.exc = exc
@ -1033,7 +1033,7 @@ class _TestBaseExitStack:
# Ensure ExitStack chaining matches actual nested `with` statements
# regarding explicit __context__ = None.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
pass
@ -1173,7 +1173,7 @@ class _TestBaseExitStack:
stack.callback(int)
def test_instance_bypass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Example(object): pass
cm = Example()
cm.__enter__ = object()
@ -1186,7 +1186,7 @@ class _TestBaseExitStack:
def test_dont_reraise_RuntimeError(self):
# https://bugs.python.org/issue27122
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class UniqueException(Exception): pass
class UniqueRuntimeError(RuntimeError): pass

View File

@ -61,19 +61,19 @@ index bdbe9b81e8f..d55f1dc54c6 100644
+
+
"""Unit tests for collections.defaultdict."""
import copy
@@ -9,7 +66,7 @@ from collections import defaultdict
def foobar():
return list
-class TestDefaultDict(unittest.TestCase):
+class TestDefaultDict(__TestCase):
def test_basic(self):
d1 = defaultdict()
@@ -127,11 +184,12 @@ class TestDefaultDict(unittest.TestCase):
def test_recursive_repr(self):
# Issue2045: stack overflow when default_factory is a bound method
- class sub(defaultdict):
@ -81,7 +81,7 @@ index bdbe9b81e8f..d55f1dc54c6 100644
- self.default_factory = self._factory
- def _factory(self):
- return []
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class sub(defaultdict):
+ def __init__(self):
+ self.default_factory = self._factory
@ -92,7 +92,7 @@ index bdbe9b81e8f..d55f1dc54c6 100644
r"sub\(<bound method .*sub\._factory "
@@ -187,4 +245,4 @@ class TestDefaultDict(unittest.TestCase):
i |= None
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -184,7 +184,7 @@ class TestDefaultDict(__TestCase):
def test_recursive_repr(self):
# Issue2045: stack overflow when default_factory is a bound method
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class sub(defaultdict):
def __init__(self):
self.default_factory = self._factory

File diff suppressed because it is too large Load Diff

View File

@ -71,7 +71,7 @@ from test.support import import_helper, get_c_recursion_limit
class DictTest(__TestCase):
def test_invalid_keyword_arguments(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Custom(dict):
pass
for invalid in {1 : 2}, Custom({1 : 2}):
@ -166,7 +166,7 @@ class DictTest(__TestCase):
def test_views_mapping(self):
mappingproxy = type(type.__dict__)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Dict(dict):
pass
for cls in [dict, Dict]:
@ -216,7 +216,7 @@ class DictTest(__TestCase):
self.assertRaises(TypeError, d.__getitem__)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadEq(object):
def __eq__(self, other):
raise Exc()
@ -227,7 +227,7 @@ class DictTest(__TestCase):
d[BadEq()] = 42
self.assertRaises(KeyError, d.__getitem__, 23)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -262,7 +262,7 @@ class DictTest(__TestCase):
self.assertRaises((TypeError, AttributeError), d.update, None)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SimpleUserDict:
def __init__(self):
self.d = {1:1, 2:2, 3:3}
@ -274,18 +274,18 @@ class DictTest(__TestCase):
d.update(SimpleUserDict())
self.assertEqual(d, {1:1, 2:2, 3:3})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
d.clear()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
raise Exc
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -303,7 +303,7 @@ class DictTest(__TestCase):
return key
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -323,7 +323,7 @@ class DictTest(__TestCase):
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class badseq(object):
def __iter__(self):
return self
@ -346,13 +346,13 @@ class DictTest(__TestCase):
yield 1
self.assertEqual(d.fromkeys(g()), {1:None})
self.assertRaises(TypeError, {}.fromkeys, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class dictlike(dict): pass
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
self.assertIsInstance(dictlike.fromkeys('a'), dictlike)
self.assertIsInstance(dictlike().fromkeys('a'), dictlike)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mydict(dict):
def __new__(cls):
return collections.UserDict()
@ -361,7 +361,7 @@ class DictTest(__TestCase):
self.assertIsInstance(ud, collections.UserDict)
self.assertRaises(TypeError, dict.fromkeys)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class baddict1(dict):
@ -370,7 +370,7 @@ class DictTest(__TestCase):
self.assertRaises(Exc, baddict1.fromkeys, [1])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadSeq(object):
def __iter__(self):
return self
@ -379,7 +379,7 @@ class DictTest(__TestCase):
self.assertRaises(Exc, dict.fromkeys, BadSeq())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict2(dict):
def __setitem__(self, key, value):
raise Exc()
@ -398,7 +398,7 @@ class DictTest(__TestCase):
self.assertEqual(dict.fromkeys(d, 0), res)
# test fast path when object's constructor returns large non-empty dict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict3(dict):
def __new__(cls):
return d
@ -408,7 +408,7 @@ class DictTest(__TestCase):
self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res)
# test slow path when object is a proper subclass of dict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict4(dict):
def __init__(self):
dict.__init__(self, d)
@ -447,7 +447,7 @@ class DictTest(__TestCase):
self.assertEqual(len(d2), len(d) + 1)
def test_copy_maintains_tracking(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
@ -495,7 +495,7 @@ class DictTest(__TestCase):
self.assertRaises(TypeError, d.setdefault)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -513,7 +513,7 @@ class DictTest(__TestCase):
def test_setdefault_atomic(self):
# Issue #13521: setdefault() calls __hash__ and __eq__ only once.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Hashed(object):
def __init__(self):
self.hash_count = 0
@ -533,7 +533,7 @@ class DictTest(__TestCase):
self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1)
def test_setitem_atomic_at_resize(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Hashed(object):
def __init__(self):
self.hash_count = 0
@ -599,7 +599,7 @@ class DictTest(__TestCase):
self.assertRaises(TypeError, d.pop)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -652,7 +652,7 @@ class DictTest(__TestCase):
def test_mutating_lookup(self):
# changing dict during a lookup (issue #14417)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NastyKey:
mutate_dict = None
@ -686,7 +686,7 @@ class DictTest(__TestCase):
d[1] = d
self.assertEqual(repr(d), '{1: {...}}')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadRepr(object):
@ -706,7 +706,7 @@ class DictTest(__TestCase):
self.assertEqual({}, {})
self.assertEqual({1: 2}, {1: 2})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadCmp(object):
@ -770,7 +770,7 @@ class DictTest(__TestCase):
self.assertFalse(larger == larger3)
def test_errors_in_view_containment_check(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __eq__(self, other):
raise RuntimeError
@ -853,7 +853,7 @@ class DictTest(__TestCase):
# (E) subclass defines __missing__ method raising RuntimeError
# (F) subclass sets __missing__ instance variable (no effect)
# (G) subclass doesn't define __missing__ at all
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class D(dict):
def __missing__(self, key):
return 42
@ -864,7 +864,7 @@ class DictTest(__TestCase):
self.assertNotIn(2, d.keys())
self.assertEqual(d[2], 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class E(dict):
def __missing__(self, key):
raise RuntimeError(key)
@ -873,7 +873,7 @@ class DictTest(__TestCase):
e[42]
self.assertEqual(c.exception.args, (42,))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F(dict):
def __init__(self):
# An instance variable __missing__ should have no effect
@ -883,7 +883,7 @@ class DictTest(__TestCase):
f[42]
self.assertEqual(c.exception.args, (42,))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class G(dict):
pass
g = G()
@ -900,7 +900,7 @@ class DictTest(__TestCase):
def test_bad_key(self):
# Dictionary lookups should fail if __eq__() raises an exception.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomException(Exception):
pass
@ -947,7 +947,7 @@ class DictTest(__TestCase):
# Another dict resizing bug (SF bug #1456209).
# This caused Segmentation faults or Illegal instructions.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(object):
def __hash__(self):
return 5
@ -977,7 +977,7 @@ class DictTest(__TestCase):
def test_container_iterator(self):
# Bug #3680: tp_traverse was not implemented for dictiter and
# dictview objects.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
pass
views = (dict.items, dict.values, dict.keys)
@ -1033,7 +1033,7 @@ class DictTest(__TestCase):
def test_track_dynamic(self):
# Test GC-optimization of dynamically-created dicts
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyObject(object):
pass
x, y, z, w, o = 1.5, "a", (1, object()), [], MyObject()
@ -1103,7 +1103,7 @@ class DictTest(__TestCase):
self._tracked(MyDict())
def make_shared_key_dict(self, n):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
pass
@ -1194,7 +1194,7 @@ class DictTest(__TestCase):
@support.cpython_only
def test_splittable_update(self):
"""dict.update(other) must preserve order in other."""
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __init__(self, order):
if order:
@ -1212,7 +1212,7 @@ class DictTest(__TestCase):
@support.cpython_only
def test_splittable_to_generic_combinedtable(self):
"""split table must be correctly resized and converted to generic combined table"""
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
pass
@ -1336,19 +1336,19 @@ class DictTest(__TestCase):
self.assertEqual(sorted(values), sorted(data.values()))
def test_instance_dict_getattr_str_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo:
def __init__(self, msg):
self.msg = msg
f = Foo('123')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class _str(str):
pass
self.assertEqual(f.msg, getattr(f, _str('msg')))
self.assertEqual(f.msg, f.__dict__[_str('msg')])
def test_object_set_item_single_instance_non_str_key(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo: pass
f = Foo()
f.__dict__[1] = 1
@ -1359,7 +1359,7 @@ class DictTest(__TestCase):
# This object will trigger mutation of the dict when replaced
# by another value. Note this relies on refcounting: the test
# won't achieve its purpose on fully-GCed Python implementations.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Mutating:
def __del__(self):
mutate(d)
@ -1385,7 +1385,7 @@ class DictTest(__TestCase):
self.check_reentrant_insertion(mutate)
def test_merge_and_mutate(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __hash__(self):
return 0
@ -1408,7 +1408,7 @@ class DictTest(__TestCase):
def test_equal_operator_modifying_operand(self):
# test fix for seg fault reported in bpo-27945 part 3.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X():
def __del__(self):
dict_b.clear()
@ -1425,7 +1425,7 @@ class DictTest(__TestCase):
self.assertTrue(dict_a == dict_b)
# test fix for seg fault reported in bpo-38588 part 1.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Y:
def __eq__(self, other):
dict_d.clear()
@ -1437,7 +1437,7 @@ class DictTest(__TestCase):
def test_fromkeys_operator_modifying_dict_operand(self):
# test fix for seg fault reported in issue 27945 part 4a.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(int):
def __hash__(self):
return 13
@ -1456,7 +1456,7 @@ class DictTest(__TestCase):
def test_fromkeys_operator_modifying_set_operand(self):
# test fix for seg fault reported in issue 27945 part 4b.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(int):
def __hash__(self):
return 13
@ -1474,7 +1474,7 @@ class DictTest(__TestCase):
pass
def test_dictitems_contains_use_after_free(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __eq__(self, other):
d.clear()
@ -1485,7 +1485,7 @@ class DictTest(__TestCase):
def test_dict_contain_use_after_free(self):
# bpo-40489
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class S(str):
def __eq__(self, other):
d.clear()
@ -1498,7 +1498,7 @@ class DictTest(__TestCase):
self.assertFalse('test' in d)
def test_init_use_after_free(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __hash__(self):
pair[:] = []
@ -1508,7 +1508,7 @@ class DictTest(__TestCase):
dict([pair])
def test_oob_indexing_dictiter_iternextitem(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(int):
def __del__(self):
d.clear()
@ -1545,7 +1545,7 @@ class DictTest(__TestCase):
self.assertEqual(list(reversed(dict().keys())), [])
def test_reverse_iterator_for_shared_shared_dicts(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def __init__(self, x, y):
if x: self.x = x
@ -1565,7 +1565,7 @@ class DictTest(__TestCase):
self.assertEqual(list(copy.items()), expected)
# dict subclass doesn't override __iter__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomDict(dict):
pass
@ -1574,7 +1574,7 @@ class DictTest(__TestCase):
d = CustomDict(pairs)
self.assertEqual(pairs, list(dict(d).items()))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomReversedDict(dict):
def keys(self):
return reversed(list(dict.keys(self)))
@ -1607,7 +1607,7 @@ class DictTest(__TestCase):
self.assertTrue(gc.is_tracked(next(it)))
def test_store_evilattr(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EvilAttr:
def __init__(self, d):
self.d = d
@ -1630,13 +1630,13 @@ class DictTest(__TestCase):
# `str` keys. Make sure the unoptimized path is used when a non-`str`
# key appears.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class StrSub(str):
pass
eq_count = 0
# This class compares equal to the string 'key3'
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key3:
def __hash__(self):
return hash('key3')
@ -1746,7 +1746,7 @@ class CAPITest(__TestCase):
# key does not exist
self.assertRaises(KeyError, dict_getitem_knownhash, {}, 1, hash(1))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadEq:
def __eq__(self, other):

View File

@ -62,7 +62,7 @@ index 97f951f1299..da82bd190c3 100644
import os
@@ -8,11 +62,84 @@ import time
import unittest
from test import support
-from test.support.testcase import FloatsAreIdenticalMixin
-from test.support.numbers import (
@ -149,14 +149,14 @@ index 97f951f1299..da82bd190c3 100644
+
from math import isinf, isnan, copysign, ldexp
import math
@@ -35,7 +162,7 @@ class FloatSubclass(float):
class OtherFloatSubclass(float):
pass
-class GeneralFloatCases(unittest.TestCase):
+class GeneralFloatCases(__TestCase):
def test_float(self):
self.assertEqual(float(3.14), 3.14)
@@ -95,9 +222,10 @@ class GeneralFloatCases(unittest.TestCase):
@ -166,51 +166,51 @@ index 97f951f1299..da82bd190c3 100644
- class CustomStr(str): pass
- class CustomBytes(bytes): pass
- class CustomByteArray(bytearray): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomStr(str): pass
+ class CustomBytes(bytes): pass
+ class CustomByteArray(bytearray): pass
factories = [
bytes,
@@ -184,30 +312,31 @@ class GeneralFloatCases(unittest.TestCase):
def test_floatconversion(self):
# Make sure that calls to __float__() work properly
- class Foo1(object):
- def __float__(self):
- return 42.
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo1(object):
+ def __float__(self):
+ return 42.
- class Foo2(float):
- def __float__(self):
- return 42.
+ class Foo2(float):
+ def __float__(self):
+ return 42.
- class Foo3(float):
- def __new__(cls, value=0.):
- return float.__new__(cls, 2*value)
+ class Foo3(float):
+ def __new__(cls, value=0.):
+ return float.__new__(cls, 2*value)
- def __float__(self):
- return self
+ def __float__(self):
+ return self
- class Foo4(float):
- def __float__(self):
- return 42
+ class Foo4(float):
+ def __float__(self):
+ return 42
- # Issue 5759: __float__ not called on str subclasses (though it is on
- # unicode subclasses).
- class FooStr(str):
@ -221,27 +221,27 @@ index 97f951f1299..da82bd190c3 100644
+ class FooStr(str):
+ def __float__(self):
+ return float(str(self)) + 1
self.assertEqual(float(Foo1()), 42.)
self.assertEqual(float(Foo2()), 42.)
@@ -216,15 +345,17 @@ class GeneralFloatCases(unittest.TestCase):
self.assertRaises(TypeError, float, Foo4(42))
self.assertEqual(float(FooStr('8')), 9.)
- class Foo5:
- def __float__(self):
- return ""
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo5:
+ def __float__(self):
+ return ""
self.assertRaises(TypeError, time.sleep, Foo5())
- # Issue #24731
- class F:
- def __float__(self):
- return OtherFloatSubclass(42.)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Issue #24731
+ class F:
+ def __float__(self):
@ -252,39 +252,39 @@ index 97f951f1299..da82bd190c3 100644
@@ -234,18 +365,20 @@ class GeneralFloatCases(unittest.TestCase):
with self.assertWarns(DeprecationWarning):
self.assertIs(type(FloatSubclass(F())), FloatSubclass)
- class MyIndex:
- def __init__(self, value):
- self.value = value
- def __index__(self):
- return self.value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyIndex:
+ def __init__(self, value):
+ self.value = value
+ def __index__(self):
+ return self.value
self.assertEqual(float(MyIndex(42)), 42.0)
self.assertRaises(OverflowError, float, MyIndex(2**2000))
- class MyInt:
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyInt:
+ def __int__(self):
+ return 42
self.assertRaises(TypeError, float, MyInt())
@@ -254,27 +387,30 @@ class GeneralFloatCases(unittest.TestCase):
float(x='3.14')
def test_keywords_in_subclass(self):
- class subclass(float):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(float):
+ pass
u = subclass(2.5)
@ -292,11 +292,11 @@ index 97f951f1299..da82bd190c3 100644
self.assertEqual(float(u), 2.5)
with self.assertRaises(TypeError):
subclass(x=0)
- class subclass_with_init(float):
- def __init__(self, arg, newarg=None):
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(float):
+ def __init__(self, arg, newarg=None):
+ self.newarg = newarg
@ -304,13 +304,13 @@ index 97f951f1299..da82bd190c3 100644
self.assertIs(type(u), subclass_with_init)
self.assertEqual(float(u), 2.5)
self.assertEqual(u.newarg, 3)
- class subclass_with_new(float):
- def __new__(cls, arg, newarg=None):
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(float):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -328,7 +328,7 @@ index 97f951f1299..da82bd190c3 100644
- return 42
- class F(float, H):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class H:
+ def __hash__(self):
+ return 42
@ -336,8 +336,8 @@ index 97f951f1299..da82bd190c3 100644
+ pass
value = F('nan')
self.assertEqual(hash(value), object.__hash__(value))
@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
-class FormatFunctionsTestCase(unittest.TestCase):
+class FormatFunctionsTestCase(__TestCase):
@ -347,25 +347,25 @@ index 97f951f1299..da82bd190c3 100644
@@ -645,7 +782,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN))
# is accident (today).
# let's also try to guarantee that -0.0 and 0.0 don't get confused.
-class IEEEFormatTestCase(unittest.TestCase):
+class IEEEFormatTestCase(__TestCase):
@support.requires_IEEE_754
def test_double_specials_do_unpack(self):
@@ -670,7 +807,7 @@ class IEEEFormatTestCase(unittest.TestCase):
self.assertEqual(struct.pack("<f", 3.40282356e38), struct.pack("<f", FLT_MAX))
self.assertEqual(struct.pack("<f", -3.40282356e38), struct.pack("<f", -FLT_MAX))
-class FormatTestCase(unittest.TestCase):
+class FormatTestCase(__TestCase):
def test_format(self):
# these should be rewritten to use both format(x, spec) and
@@ -767,7 +904,7 @@ class FormatTestCase(unittest.TestCase):
self.assertEqual(format(-123.34, '00.10e'), '-1.2334000000e+02')
self.assertEqual(format(-123.34, '00.10g'), '-123.34')
-class ReprTestCase(unittest.TestCase):
+class ReprTestCase(__TestCase):
def test_repr(self):
@ -373,7 +373,7 @@ index 97f951f1299..da82bd190c3 100644
'mathdata',
@@ -832,7 +969,29 @@ class ReprTestCase(unittest.TestCase):
self.assertEqual(repr(float(negs)), str(float(negs)))
@support.requires_IEEE_754
-class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
+class RoundTestCase(__TestCase):
@ -399,11 +399,11 @@ index 97f951f1299..da82bd190c3 100644
+ else:
+ msg += ': zeros have different signs'
+ self.fail(msg.format(x, y))
def test_inf_nan(self):
self.assertRaises(OverflowError, round, INF)
@@ -955,7 +1114,7 @@ class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
# Beginning with Python 2.6 float has cross platform compatible
# ways to create and represent inf and nan
-class InfNanTest(unittest.TestCase):
@ -412,7 +412,7 @@ index 97f951f1299..da82bd190c3 100644
self.assertTrue(isinf(float("inf")))
self.assertTrue(isinf(float("+inf")))
@@ -1056,12 +1215,35 @@ class InfNanTest(unittest.TestCase):
fromHex = float.fromhex
toHex = float.hex
-class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
@ -421,7 +421,7 @@ index 97f951f1299..da82bd190c3 100644
MIN = fromHex('0x1p-1022') # min normal
TINY = fromHex('0x0.0000000000001p-1022') # min subnormal
EPS = fromHex('0x0.0000000000001p0') # diff between 1.0 and next float up
+ def assertFloatsAreIdentical(self, x, y):
+ """assert that floats x and y are identical, in the sense that:
+ (1) both x and y are nans, or
@ -447,37 +447,37 @@ index 97f951f1299..da82bd190c3 100644
+
def identical(self, x, y):
self.assertFloatsAreIdentical(x, y)
@@ -1482,17 +1664,19 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
self.identical(x, fromHex(toHex(x)))
def test_subclass(self):
- class F(float):
- def __new__(cls, value):
- return float.__new__(cls, value + 1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class F(float):
+ def __new__(cls, value):
+ return float.__new__(cls, value + 1)
f = F.fromhex((1.5).hex())
self.assertIs(type(f), F)
self.assertEqual(f, 2.5)
- class F2(float):
- def __init__(self, value):
- self.foo = 'bar'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class F2(float):
+ def __init__(self, value):
+ self.foo = 'bar'
f = F2.fromhex((1.5).hex())
self.assertIs(type(f), F2)
@@ -1500,5 +1684,5 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
self.assertEqual(getattr(f, 'foo', 'none'), 'bar')
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":

View File

@ -222,7 +222,7 @@ class GeneralFloatCases(__TestCase):
def test_non_numeric_input_types(self):
# Test possible non-numeric types for the argument x, including
# subclasses of the explicitly documented accepted types.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomStr(str): pass
class CustomBytes(bytes): pass
class CustomByteArray(bytearray): pass
@ -312,7 +312,7 @@ class GeneralFloatCases(__TestCase):
def test_floatconversion(self):
# Make sure that calls to __float__() work properly
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo1(object):
def __float__(self):
return 42.
@ -345,13 +345,13 @@ class GeneralFloatCases(__TestCase):
self.assertRaises(TypeError, float, Foo4(42))
self.assertEqual(float(FooStr('8')), 9.)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo5:
def __float__(self):
return ""
self.assertRaises(TypeError, time.sleep, Foo5())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Issue #24731
class F:
def __float__(self):
@ -365,7 +365,7 @@ class GeneralFloatCases(__TestCase):
with self.assertWarns(DeprecationWarning):
self.assertIs(type(FloatSubclass(F())), FloatSubclass)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyIndex:
def __init__(self, value):
self.value = value
@ -375,7 +375,7 @@ class GeneralFloatCases(__TestCase):
self.assertEqual(float(MyIndex(42)), 42.0)
self.assertRaises(OverflowError, float, MyIndex(2**2000))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyInt:
def __int__(self):
return 42
@ -387,7 +387,7 @@ class GeneralFloatCases(__TestCase):
float(x='3.14')
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(float):
pass
u = subclass(2.5)
@ -396,7 +396,7 @@ class GeneralFloatCases(__TestCase):
with self.assertRaises(TypeError):
subclass(x=0)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(float):
def __init__(self, arg, newarg=None):
self.newarg = newarg
@ -405,7 +405,7 @@ class GeneralFloatCases(__TestCase):
self.assertEqual(float(u), 2.5)
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(float):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -746,7 +746,7 @@ class GeneralFloatCases(__TestCase):
def test_hash_nan(self):
value = float('nan')
self.assertEqual(hash(value), object.__hash__(value))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class H:
def __hash__(self):
return 42
@ -1664,7 +1664,7 @@ class HexFloatTestCase(__TestCase):
self.identical(x, fromHex(toHex(x)))
def test_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F(float):
def __new__(cls, value):
return float.__new__(cls, value + 1)
@ -1673,7 +1673,7 @@ class HexFloatTestCase(__TestCase):
self.assertIs(type(f), F)
self.assertEqual(f, 2.5)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F2(float):
def __init__(self, value):
self.foo = 'bar'

View File

@ -59,7 +59,7 @@ index 48825f46911..731680d82a0 100644
+
import sys
import time
import unittest
from unittest import mock
from test import support
@ -144,35 +144,35 @@ index 48825f46911..731680d82a0 100644
+ '(1+1.5_j_)',
+ '(1+1.5_j)',
+]
try:
import _pylong
@@ -38,7 +165,7 @@ L = [
class IntSubclass(int):
pass
-class IntTestCases(unittest.TestCase):
+class IntTestCases(__TestCase):
def test_basic(self):
self.assertEqual(int(314), 314)
@@ -309,11 +436,13 @@ class IntTestCases(unittest.TestCase):
int('0', 5.0)
def test_int_base_indexable(self):
- class MyIndexable(object):
- def __init__(self, value):
- self.value = value
- def __index__(self):
- return self.value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyIndexable(object):
+ def __init__(self, value):
+ self.value = value
+ def __index__(self):
+ return self.value
# Check out of range bases.
for base in 2**100, -2**100, 1, 37:
@@ -328,9 +457,11 @@ class IntTestCases(unittest.TestCase):
@ -183,44 +183,44 @@ index 48825f46911..731680d82a0 100644
- class CustomBytes(bytes): pass
- class CustomByteArray(bytearray): pass
+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomStr(str): pass
+ class CustomBytes(bytes): pass
+ class CustomByteArray(bytearray): pass
factories = [
bytes,
@@ -372,72 +503,82 @@ class IntTestCases(unittest.TestCase):
def test_intconversion(self):
# Test __int__()
- class ClassicMissingMethods:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ClassicMissingMethods:
+ pass
self.assertRaises(TypeError, int, ClassicMissingMethods())
- class MissingMethods(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MissingMethods(object):
+ pass
self.assertRaises(TypeError, int, MissingMethods())
- class Foo0:
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo0:
+ def __int__(self):
+ return 42
self.assertEqual(int(Foo0()), 42)
- class Classic:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Classic:
+ pass
for base in (object, Classic):
@ -229,35 +229,35 @@ index 48825f46911..731680d82a0 100644
- return 42
- def __trunc__(self):
- return -12
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IntOverridesTrunc(base):
+ def __int__(self):
+ return 42
+ def __trunc__(self):
+ return -12
self.assertEqual(int(IntOverridesTrunc()), 42)
- class JustTrunc(base):
- def __trunc__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class JustTrunc(base):
+ def __trunc__(self):
+ return 42
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(JustTrunc()), 42)
- class ExceptionalTrunc(base):
- def __trunc__(self):
- 1 / 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ExceptionalTrunc(base):
+ def __trunc__(self):
+ 1 / 0
with self.assertRaises(ZeroDivisionError), \
self.assertWarns(DeprecationWarning):
int(ExceptionalTrunc())
for trunc_result_base in (object, Classic):
- class Index(trunc_result_base):
- def __index__(self):
@ -266,7 +266,7 @@ index 48825f46911..731680d82a0 100644
- class TruncReturnsNonInt(base):
- def __trunc__(self):
- return Index()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Index(trunc_result_base):
+ def __index__(self):
+ return 42
@ -276,15 +276,15 @@ index 48825f46911..731680d82a0 100644
+ return Index()
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(TruncReturnsNonInt()), 42)
- class Intable(trunc_result_base):
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Intable(trunc_result_base):
+ def __int__(self):
+ return 42
- class TruncReturnsNonIndex(base):
- def __trunc__(self):
- return Intable()
@ -293,17 +293,17 @@ index 48825f46911..731680d82a0 100644
+ return Intable()
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(TruncReturnsNonInt()), 42)
- class NonIntegral(trunc_result_base):
- def __trunc__(self):
- # Check that we avoid infinite recursion.
- return NonIntegral()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NonIntegral(trunc_result_base):
+ def __trunc__(self):
+ # Check that we avoid infinite recursion.
+ return NonIntegral()
- class TruncReturnsNonIntegral(base):
- def __trunc__(self):
- return NonIntegral()
@ -316,152 +316,152 @@ index 48825f46911..731680d82a0 100644
@@ -449,27 +590,29 @@ class IntTestCases(unittest.TestCase):
self.fail("Failed to raise TypeError with %s" %
((base, trunc_result_base),))
- # Regression test for bugs.python.org/issue16060.
- class BadInt(trunc_result_base):
- def __int__(self):
- return 42.0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Regression test for bugs.python.org/issue16060.
+ class BadInt(trunc_result_base):
+ def __int__(self):
+ return 42.0
- class TruncReturnsBadInt(base):
- def __trunc__(self):
- return BadInt()
+ class TruncReturnsBadInt(base):
+ def __trunc__(self):
+ return BadInt()
with self.assertRaises(TypeError), \
self.assertWarns(DeprecationWarning):
int(TruncReturnsBadInt())
def test_int_subclass_with_index(self):
- class MyIndex(int):
- def __index__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyIndex(int):
+ def __index__(self):
+ return 42
- class BadIndex(int):
- def __index__(self):
- return 42.0
+ class BadIndex(int):
+ def __index__(self):
+ return 42.0
my_int = MyIndex(7)
self.assertEqual(my_int, 7)
@@ -478,13 +621,14 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(int(BadIndex()), 0)
def test_int_subclass_with_int(self):
- class MyInt(int):
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyInt(int):
+ def __int__(self):
+ return 42
- class BadInt(int):
- def __int__(self):
- return 42.0
+ class BadInt(int):
+ def __int__(self):
+ return 42.0
my_int = MyInt(7)
self.assertEqual(my_int, 7)
@@ -495,33 +639,34 @@ class IntTestCases(unittest.TestCase):
self.assertRaises(TypeError, int, my_int)
def test_int_returns_int_subclass(self):
- class BadIndex:
- def __index__(self):
- return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadIndex:
+ def __index__(self):
+ return True
- class BadIndex2(int):
- def __index__(self):
- return True
+ class BadIndex2(int):
+ def __index__(self):
+ return True
- class BadInt:
- def __int__(self):
- return True
+ class BadInt:
+ def __int__(self):
+ return True
- class BadInt2(int):
- def __int__(self):
- return True
+ class BadInt2(int):
+ def __int__(self):
+ return True
- class TruncReturnsBadIndex:
- def __trunc__(self):
- return BadIndex()
+ class TruncReturnsBadIndex:
+ def __trunc__(self):
+ return BadIndex()
- class TruncReturnsBadInt:
- def __trunc__(self):
- return BadInt()
+ class TruncReturnsBadInt:
+ def __trunc__(self):
+ return BadInt()
- class TruncReturnsIntSubclass:
- def __trunc__(self):
- return True
+ class TruncReturnsIntSubclass:
+ def __trunc__(self):
+ return True
bad_int = BadIndex()
with self.assertWarns(DeprecationWarning):
@@ -566,6 +711,7 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(n, 1)
self.assertIs(type(n), IntSubclass)
+ @skipIfTorchDynamo("flaky under dynamo")
def test_error_message(self):
def check(s, base=None):
with self.assertRaises(ValueError,
@@ -607,7 +753,7 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
-class IntStrDigitLimitsTests(unittest.TestCase):
+class IntStrDigitLimitsTests(__TestCase):
int_class = int # Override this in subclasses to reuse the suite.
@@ -818,7 +964,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
int_class = IntSubclass
-class PyLongModuleTests(unittest.TestCase):
+class PyLongModuleTests(__TestCase):
# Tests of the functions in _pylong.py. Those get used when the
# number of digits in the input values are large enough.
@@ -922,4 +1068,4 @@ class PyLongModuleTests(unittest.TestCase):
bits <<= 1
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -436,8 +436,8 @@ class IntTestCases(__TestCase):
int('0', 5.0)
def test_int_base_indexable(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
with torch._dynamo.error_on_graph_break(False):
class MyIndexable(object):
def __init__(self, value):
self.value = value
@ -458,7 +458,7 @@ class IntTestCases(__TestCase):
# Test possible non-numeric types for the argument x, including
# subclasses of the explicitly documented accepted types.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomStr(str): pass
class CustomBytes(bytes): pass
class CustomByteArray(bytearray): pass
@ -503,28 +503,28 @@ class IntTestCases(__TestCase):
def test_intconversion(self):
# Test __int__()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ClassicMissingMethods:
pass
self.assertRaises(TypeError, int, ClassicMissingMethods())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MissingMethods(object):
pass
self.assertRaises(TypeError, int, MissingMethods())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo0:
def __int__(self):
return 42
self.assertEqual(int(Foo0()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Classic:
pass
for base in (object, Classic):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IntOverridesTrunc(base):
def __int__(self):
return 42
@ -532,14 +532,14 @@ class IntTestCases(__TestCase):
return -12
self.assertEqual(int(IntOverridesTrunc()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class JustTrunc(base):
def __trunc__(self):
return 42
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(JustTrunc()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ExceptionalTrunc(base):
def __trunc__(self):
1 / 0
@ -548,7 +548,7 @@ class IntTestCases(__TestCase):
int(ExceptionalTrunc())
for trunc_result_base in (object, Classic):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Index(trunc_result_base):
def __index__(self):
return 42
@ -559,7 +559,7 @@ class IntTestCases(__TestCase):
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(TruncReturnsNonInt()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Intable(trunc_result_base):
def __int__(self):
return 42
@ -570,7 +570,7 @@ class IntTestCases(__TestCase):
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(TruncReturnsNonInt()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NonIntegral(trunc_result_base):
def __trunc__(self):
# Check that we avoid infinite recursion.
@ -590,7 +590,7 @@ class IntTestCases(__TestCase):
self.fail("Failed to raise TypeError with %s" %
((base, trunc_result_base),))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Regression test for bugs.python.org/issue16060.
class BadInt(trunc_result_base):
def __int__(self):
@ -605,7 +605,7 @@ class IntTestCases(__TestCase):
int(TruncReturnsBadInt())
def test_int_subclass_with_index(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyIndex(int):
def __index__(self):
return 42
@ -621,7 +621,7 @@ class IntTestCases(__TestCase):
self.assertEqual(int(BadIndex()), 0)
def test_int_subclass_with_int(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyInt(int):
def __int__(self):
return 42
@ -639,7 +639,7 @@ class IntTestCases(__TestCase):
self.assertRaises(TypeError, int, my_int)
def test_int_returns_int_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadIndex:
def __index__(self):
return True

View File

@ -61,15 +61,15 @@ index 1b9f3cf7624..6560c7423a6 100644
+# ======= END DYNAMO PATCH =======
+
# Test iterators.
import sys
@@ -104,12 +161,10 @@ class EmptyIterClass:
# Main test suite
-class TestCase(unittest.TestCase):
+class TestCase(__TestCase):
# Helper to check that an iterator returns a given sequence
def check_iterator(self, it, seq, pickle=True):
- if pickle:
@ -78,7 +78,7 @@ index 1b9f3cf7624..6560c7423a6 100644
while 1:
try:
@@ -121,8 +176,6 @@ class TestCase(unittest.TestCase):
# Helper to check that a for loop generates a given sequence
def check_for_loop(self, expr, seq, pickle=True):
- if pickle:
@ -89,7 +89,7 @@ index 1b9f3cf7624..6560c7423a6 100644
@@ -261,19 +314,20 @@ class TestCase(unittest.TestCase):
def run(builtin_name, item, sentinel=None):
it = iter(item) if sentinel is None else iter(item, sentinel)
- class CustomStr:
- def __init__(self, name, iterator):
- self.name = name
@ -103,7 +103,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- # the pointers after this call
- list(self.iterator)
- return other == self.name
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomStr:
+ def __init__(self, name, iterator):
+ self.name = name
@ -117,25 +117,25 @@ index 1b9f3cf7624..6560c7423a6 100644
+ # the pointers after this call
+ list(self.iterator)
+ return other == self.name
# del is required here
# to not prematurely call __eq__ from
@@ -323,9 +377,10 @@ class TestCase(unittest.TestCase):
# Test a new_style class with __iter__ but no next() method
def test_new_style_iter_class(self):
- class IterClass(object):
- def __iter__(self):
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IterClass(object):
+ def __iter__(self):
+ return self
self.assertRaises(TypeError, iter, IterClass())
# Test two-argument iter() with callable instance
@@ -394,11 +449,12 @@ class TestCase(unittest.TestCase):
# Test exception propagation through sequence iterator
def test_exception_sequence(self):
- class MySequenceClass(SequenceClass):
@ -143,7 +143,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- if i == 10:
- raise RuntimeError
- return SequenceClass.__getitem__(self, i)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySequenceClass(SequenceClass):
+ def __getitem__(self, i):
+ if i == 10:
@ -153,7 +153,7 @@ index 1b9f3cf7624..6560c7423a6 100644
try:
for x in MySequenceClass(20):
@@ -410,11 +466,12 @@ class TestCase(unittest.TestCase):
# Test for StopIteration from __getitem__
def test_stop_sequence(self):
- class MySequenceClass(SequenceClass):
@ -161,25 +161,25 @@ index 1b9f3cf7624..6560c7423a6 100644
- if i == 10:
- raise StopIteration
- return SequenceClass.__getitem__(self, i)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySequenceClass(SequenceClass):
+ def __getitem__(self, i):
+ if i == 10:
+ raise StopIteration
+ return SequenceClass.__getitem__(self, i)
self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
# Test a big range
@@ -541,32 +598,34 @@ class TestCase(unittest.TestCase):
self.assertRaises(TypeError, filter, None, list)
self.assertRaises(TypeError, filter, None, 42)
- class Boolean:
- def __init__(self, truth):
- self.truth = truth
- def __bool__(self):
- return self.truth
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Boolean:
+ def __init__(self, truth):
+ self.truth = truth
@ -187,7 +187,7 @@ index 1b9f3cf7624..6560c7423a6 100644
+ return self.truth
bTrue = Boolean(True)
bFalse = Boolean(False)
- class Seq:
- def __init__(self, *args):
- self.vals = args
@ -206,7 +206,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- else:
- raise StopIteration
- return SeqIter(self.vals)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Seq:
+ def __init__(self, *args):
+ self.vals = args
@ -225,12 +225,12 @@ index 1b9f3cf7624..6560c7423a6 100644
+ else:
+ raise StopIteration
+ return SeqIter(self.vals)
seq = Seq(*([bTrue, bFalse] * 25))
self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
@@ -635,6 +694,7 @@ class TestCase(unittest.TestCase):
pass
# Test zip()'s use of iterators.
+ @skipIfTorchDynamo("infinite loop")
def test_builtin_zip(self):
@ -238,21 +238,21 @@ index 1b9f3cf7624..6560c7423a6 100644
self.assertEqual(list(zip(*[])), [])
@@ -653,17 +713,18 @@ class TestCase(unittest.TestCase):
self.assertEqual(list(d.items()), list(zip(d, d.values())))
# Generate all ints starting at constructor arg.
- class IntsFrom:
- def __init__(self, start):
- self.i = start
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IntsFrom:
+ def __init__(self, start):
+ self.i = start
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- def __next__(self):
- i = self.i
- self.i = i+1
@ -261,60 +261,60 @@ index 1b9f3cf7624..6560c7423a6 100644
+ i = self.i
+ self.i = i+1
+ return i
f = open(TESTFN, "w", encoding="utf-8")
try:
@@ -686,19 +747,20 @@ class TestCase(unittest.TestCase):
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
# Classes that lie about their lengths.
- class NoGuessLen5:
- def __getitem__(self, i):
- if i >= 5:
- raise IndexError
- return i
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NoGuessLen5:
+ def __getitem__(self, i):
+ if i >= 5:
+ raise IndexError
+ return i
- class Guess3Len5(NoGuessLen5):
- def __len__(self):
- return 3
+ class Guess3Len5(NoGuessLen5):
+ def __len__(self):
+ return 3
- class Guess30Len5(NoGuessLen5):
- def __len__(self):
- return 30
+ class Guess30Len5(NoGuessLen5):
+ def __len__(self):
+ return 30
def lzip(*args):
return list(zip(*args))
@@ -718,20 +780,21 @@ class TestCase(unittest.TestCase):
# This class inserts a Unicode object into its argument's natural
# iteration, in the 3rd position.
- class OhPhooey:
- def __init__(self, seq):
- self.it = iter(seq)
- self.i = 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class OhPhooey:
+ def __init__(self, seq):
+ self.it = iter(seq)
+ self.i = 0
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- def __next__(self):
- i = self.i
- self.i = i+1
@ -327,25 +327,25 @@ index 1b9f3cf7624..6560c7423a6 100644
+ if i == 2:
+ return "fooled you!"
+ return next(self.it)
f = open(TESTFN, "w", encoding="utf-8")
try:
@@ -895,29 +958,30 @@ class TestCase(unittest.TestCase):
f.writelines({})
# Try a big chunk too.
- class Iterator:
- def __init__(self, start, finish):
- self.start = start
- self.finish = finish
- self.i = self.start
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Iterator:
+ def __init__(self, start, finish):
+ self.start = start
+ self.finish = finish
+ self.i = self.start
- def __next__(self):
- if self.i >= self.finish:
- raise StopIteration
@ -358,12 +358,12 @@ index 1b9f3cf7624..6560c7423a6 100644
+ result = str(self.i) + '\n'
+ self.i += 1
+ return result
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- class Whatever:
- def __init__(self, start, finish):
- self.start = start
@ -372,16 +372,16 @@ index 1b9f3cf7624..6560c7423a6 100644
+ def __init__(self, start, finish):
+ self.start = start
+ self.finish = finish
- def __iter__(self):
- return Iterator(self.start, self.finish)
+ def __iter__(self):
+ return Iterator(self.start, self.finish)
f.writelines(Whatever(6, 6+2000))
f.close()
@@ -990,15 +1054,16 @@ class TestCase(unittest.TestCase):
@cpython_only
def test_ref_counting_behavior(self):
- class C(object):
@ -393,7 +393,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- cls = self.__class__
- assert cls.count > 0
- cls.count -= 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ count = 0
+ def __new__(cls):
@ -407,7 +407,7 @@ index 1b9f3cf7624..6560c7423a6 100644
self.assertEqual(C.count, 1)
del x
@@ -1089,12 +1154,13 @@ class TestCase(unittest.TestCase):
def test_3720(self):
# Avoid a crash, when an iterator deletes its next() method.
- class BadIterator(object):
@ -416,19 +416,19 @@ index 1b9f3cf7624..6560c7423a6 100644
- def __next__(self):
- del BadIterator.__next__
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadIterator(object):
+ def __iter__(self):
+ return self
+ def __next__(self):
+ del BadIterator.__next__
+ return 1
try:
for i in BadIterator() :
@@ -1187,4 +1253,4 @@ class TestCase(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -314,7 +314,7 @@ class TestCase(__TestCase):
def run(builtin_name, item, sentinel=None):
it = iter(item) if sentinel is None else iter(item, sentinel)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomStr:
def __init__(self, name, iterator):
self.name = name
@ -377,7 +377,7 @@ class TestCase(__TestCase):
# Test a new_style class with __iter__ but no next() method
def test_new_style_iter_class(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IterClass(object):
def __iter__(self):
return self
@ -449,7 +449,7 @@ class TestCase(__TestCase):
# Test exception propagation through sequence iterator
def test_exception_sequence(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySequenceClass(SequenceClass):
def __getitem__(self, i):
if i == 10:
@ -466,7 +466,7 @@ class TestCase(__TestCase):
# Test for StopIteration from __getitem__
def test_stop_sequence(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySequenceClass(SequenceClass):
def __getitem__(self, i):
if i == 10:
@ -598,7 +598,7 @@ class TestCase(__TestCase):
self.assertRaises(TypeError, filter, None, list)
self.assertRaises(TypeError, filter, None, 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Boolean:
def __init__(self, truth):
self.truth = truth
@ -607,7 +607,7 @@ class TestCase(__TestCase):
bTrue = Boolean(True)
bFalse = Boolean(False)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Seq:
def __init__(self, *args):
self.vals = args
@ -713,7 +713,7 @@ class TestCase(__TestCase):
self.assertEqual(list(d.items()), list(zip(d, d.values())))
# Generate all ints starting at constructor arg.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IntsFrom:
def __init__(self, start):
self.i = start
@ -747,7 +747,7 @@ class TestCase(__TestCase):
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
# Classes that lie about their lengths.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NoGuessLen5:
def __getitem__(self, i):
if i >= 5:
@ -780,7 +780,7 @@ class TestCase(__TestCase):
# This class inserts a Unicode object into its argument's natural
# iteration, in the 3rd position.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class OhPhooey:
def __init__(self, seq):
self.it = iter(seq)
@ -958,7 +958,7 @@ class TestCase(__TestCase):
f.writelines({})
# Try a big chunk too.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Iterator:
def __init__(self, start, finish):
self.start = start
@ -1054,7 +1054,7 @@ class TestCase(__TestCase):
@cpython_only
def test_ref_counting_behavior(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
count = 0
def __new__(cls):
@ -1154,7 +1154,7 @@ class TestCase(__TestCase):
def test_3720(self):
# Avoid a crash, when an iterator deletes its next() method.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadIterator(object):
def __iter__(self):
return self

View File

@ -31,7 +31,7 @@ index 7d5ba727389..ff514815da2 100644
@@ -40,6 +62,14 @@ def pickle_deprecated(testfunc):
maxsize = support.MAX_Py_ssize_t
minsize = -maxsize-1
+@torch._dynamo.disable
+def choice(*args):
+ return random.choice(*args)
@ -42,33 +42,33 @@ index 7d5ba727389..ff514815da2 100644
+
def lzip(*args):
return list(zip(*args))
@@ -90,10 +120,10 @@ def fact(n):
return prod(range(1, n+1))
# root level methods for pickling ability
-def testR(r):
+def _testR(r):
return r[0]
-def testR2(r):
+def _testR2(r):
return r[2]
def underten(x):
@@ -102,7 +132,7 @@ def underten(x):
picklecopiers = [lambda s, proto=proto: pickle.loads(pickle.dumps(s, proto))
for proto in range(pickle.HIGHEST_PROTOCOL + 1)]
-class TestBasicOps(unittest.TestCase):
+class TestBasicOps(__TestCase):
def pickletest(self, protocol, it, stop=4, take=1, compare=None):
"""Test that an iterator is the same after pickling, also when part-consumed"""
@@ -454,14 +484,8 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
- @pickle_deprecated
def test_permutations(self):
- self.assertRaises(TypeError, permutations) # too few arguments
@ -79,11 +79,11 @@ index 7d5ba727389..ff514815da2 100644
- self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
@@ -498,7 +522,7 @@ class TestBasicOps(unittest.TestCase):
if len(set(indices)) == r:
yield tuple(pool[i] for i in indices)
- for n in range(7):
+ for n in range(5):
values = [5*x-12 for x in range(n)]
@ -92,7 +92,7 @@ index 7d5ba727389..ff514815da2 100644
@@ -515,9 +539,6 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(permutations(values, None))) # test r as None
self.assertEqual(result, list(permutations(values))) # test default r
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- self.pickletest(proto, permutations(values, r)) # test pickling
-
@ -107,7 +107,7 @@ index 7d5ba727389..ff514815da2 100644
+ # self.assertRaises(TypeError, cycle)
self.assertRaises(TypeError, cycle, 5)
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
@@ -888,7 +909,7 @@ class TestBasicOps(unittest.TestCase):
# Check normal pickled
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@ -118,7 +118,7 @@ index 7d5ba727389..ff514815da2 100644
self.assertEqual(k, elem[0])
dup.append(elem)
@@ -896,8 +917,8 @@ class TestBasicOps(unittest.TestCase):
# Check nested case
dup = []
- for k, g in groupby(s, testR):
@ -140,8 +140,8 @@ index 7d5ba727389..ff514815da2 100644
self.assertEqual(k, elem[0])
self.assertEqual(ik, elem[2])
@@ -917,7 +938,7 @@ class TestBasicOps(unittest.TestCase):
# Check case where inner iterator is not used
- keys = [k for k, g in groupby(s, testR)]
+ keys = [k for k, g in groupby(s, _testR)]
@ -159,7 +159,7 @@ index 7d5ba727389..ff514815da2 100644
_, g3 = next(it)
@@ -936,7 +957,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(g3), [])
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- it = groupby(s, testR)
+ it = groupby(s, _testR)
@ -182,7 +182,7 @@ index 7d5ba727389..ff514815da2 100644
+ # self.assertRaises(TypeError, filter, isEven, 3)
+ # dynamo raises Unsupported in this case
+ # self.assertRaises(TypeError, next, filter(range(6), range(6)))
# check copy, deepcopy, pickle
- ans = [0,2,4]
-
@ -212,7 +212,7 @@ index 7d5ba727389..ff514815da2 100644
+ # for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ # c = filter(isEven, range(6))
+ # self.pickletest(proto, c)
- @pickle_deprecated
def test_filterfalse(self):
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
@ -224,11 +224,11 @@ index 7d5ba727389..ff514815da2 100644
- self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- self.pickletest(proto, filterfalse(isEven, range(6)))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ self.pickletest(proto, filterfalse(isEven, range(6)))
def test_zip(self):
# XXX This is rather silly now that builtin zip() calls zip()...
@@ -1047,8 +1070,8 @@ class TestBasicOps(unittest.TestCase):
@ -243,7 +243,7 @@ index 7d5ba727389..ff514815da2 100644
lzip('abc', 'def'))
self.assertEqual([pair for pair in zip('abc', 'def')],
@@ -1105,19 +1128,19 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(zip_longest('abc', 'defg', **{})),
list(zip(list('abc')+[None], 'defg'))) # empty keyword dict
- self.assertRaises(TypeError, zip_longest, 3)
@ -272,7 +272,7 @@ index 7d5ba727389..ff514815da2 100644
+ # pass
+ # else:
+ # self.fail('Did not raise Type in: ' + stmt)
self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')],
list(zip('abc', 'def')))
@@ -1296,7 +1319,6 @@ class TestBasicOps(unittest.TestCase):
@ -280,7 +280,7 @@ index 7d5ba727389..ff514815da2 100644
list(product(*args, **dict(repeat=r))))
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
- self.assertRaises(TypeError, product, range(6), None)
def product1(*args, **kwds):
pools = list(map(tuple, args)) * kwds.get('repeat', 1)
@@ -1336,7 +1358,8 @@ class TestBasicOps(unittest.TestCase):
@ -295,7 +295,7 @@ index 7d5ba727389..ff514815da2 100644
self.assertEqual(list(product(*args)), list(product1(*args)))
@@ -1767,6 +1790,7 @@ class TestBasicOps(unittest.TestCase):
script_helper.assert_python_ok("-c", script)
# Issue 13454: Crash when deleting backward iterator from tee()
+ @skipIfTorchDynamo("infinite loop in torch dynamo")
def test_tee_del_backward(self):
@ -303,68 +303,68 @@ index 7d5ba727389..ff514815da2 100644
try:
@@ -1920,7 +1944,7 @@ class TestBasicOps(unittest.TestCase):
tp.foobar = 1
-class TestExamples(unittest.TestCase):
+class TestExamples(__TestCase):
def test_accumulate(self):
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
@@ -2032,7 +2056,7 @@ class TestExamples(unittest.TestCase):
self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
-class TestPurePythonRoughEquivalents(unittest.TestCase):
+class TestPurePythonRoughEquivalents(__TestCase):
def test_batched_recipe(self):
def batched_recipe(iterable, n):
@@ -2081,6 +2105,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
for i, element in zip(range(i + 1, stop), iterable):
pass
+ @skipIfTorchDynamo("infinite loop in torch dynamo")
def test_islice_recipe(self):
self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB'))
self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD'))
@@ -2265,7 +2290,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
raise
-class TestGC(unittest.TestCase):
+class TestGC(__TestCase):
def makecycle(self, iterator, container):
container.append(iterator)
@@ -2465,7 +2490,7 @@ def L(seqn):
return chain(map(lambda x:x, R(Ig(G(seqn)))))
-class TestVariousIteratorArgs(unittest.TestCase):
+class TestVariousIteratorArgs(__TestCase):
def test_accumulate(self):
s = [1,2,3,4,5]
@@ -2644,7 +2669,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, tee, N(s))
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
-class LengthTransparency(unittest.TestCase):
+class LengthTransparency(__TestCase):
def test_repeat(self):
self.assertEqual(operator.length_hint(repeat(None, 50)), 50)
@@ -2657,7 +2682,7 @@ class LengthTransparency(unittest.TestCase):
self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0)
self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0)
-class RegressionTests(unittest.TestCase):
+class RegressionTests(__TestCase):
def test_sf_793826(self):
# Fix Armin Rigo's successful efforts to wreak havoc
@@ -2718,6 +2743,7 @@ class RegressionTests(unittest.TestCase):
@support.skip_if_pgo_task
@support.requires_resource('cpu')
+ @slowTest
@ -373,8 +373,8 @@ index 7d5ba727389..ff514815da2 100644
# dealing with long chains of empty iterables. Even with a high
@@ -2750,7 +2776,7 @@ class RegressionTests(unittest.TestCase):
next(g, None) # shouldn't crash
-class SubclassWithKwargsTest(unittest.TestCase):
+class SubclassWithKwargsTest(__TestCase):
def test_keywords_in_subclass(self):
@ -382,8 +382,8 @@ index 7d5ba727389..ff514815da2 100644
testcases = [
@@ -2805,49 +2831,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
self.assertEqual(u.newarg, 3)
-@support.cpython_only
-class SizeofTest(unittest.TestCase):
- def setUp(self):

View File

@ -1056,7 +1056,7 @@ class TestBasicOps(__TestCase):
self.assertRaises(TypeError, filterfalse, lambda x:x)
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
self.assertRaises(TypeError, filterfalse, isEven, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, filterfalse(isEven, range(6)))
@ -1358,7 +1358,7 @@ class TestBasicOps(__TestCase):
argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
set('abcdefg'), range(11), tuple(range(13))]
for i in range(100):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
args = [choice(argtypes) for j in range(randrange(5))]
expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len)

View File

@ -67,17 +67,17 @@ index 23ef902aa0b..b9afb1ef26e 100644
@@ -36,7 +90,7 @@ class ListTest(list_tests.CommonTest):
# earlier due to a newlib bug. See the following mailing list
# thread for the details:
self.assertRaises(MemoryError, list, range(sys.maxsize // 2))
# This code used to segfault in Py2.4a3
@@ -49,28 +103,31 @@ class ListTest(list_tests.CommonTest):
list(sequence=[])
def test_keywords_in_subclass(self):
- class subclass(list):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(list):
+ pass
u = subclass([1, 2])
@ -85,12 +85,12 @@ index 23ef902aa0b..b9afb1ef26e 100644
self.assertEqual(list(u), [1, 2])
with self.assertRaises(TypeError):
subclass(sequence=())
- class subclass_with_init(list):
- def __init__(self, seq, newarg=None):
- super().__init__(seq)
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(list):
+ def __init__(self, seq, newarg=None):
+ super().__init__(seq)
@ -99,13 +99,13 @@ index 23ef902aa0b..b9afb1ef26e 100644
self.assertIs(type(u), subclass_with_init)
self.assertEqual(list(u), [1, 2])
self.assertEqual(u.newarg, 3)
- class subclass_with_new(list):
- def __new__(cls, seq, newarg=None):
- self = super().__new__(cls, seq)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(list):
+ def __new__(cls, seq, newarg=None):
+ self = super().__new__(cls, seq)
@ -116,7 +116,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
self.assertEqual(list(u), [1, 2])
@@ -117,14 +174,15 @@ class ListTest(list_tests.CommonTest):
lst *= size
def test_repr_mutate(self):
- class Obj:
- @staticmethod
@ -126,7 +126,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- except IndexError:
- pass
- return 'obj'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Obj:
+ @staticmethod
+ def __repr__():
@ -135,7 +135,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
+ except IndexError:
+ pass
+ return 'obj'
mylist = [Obj() for _ in range(5)]
self.assertEqual(repr(mylist), '[obj, obj, obj]')
@@ -220,26 +278,28 @@ class ListTest(list_tests.CommonTest):
@ -143,11 +143,11 @@ index 23ef902aa0b..b9afb1ef26e 100644
# optimization causes failures in code that relies on distinct
# function addresses.
- class L(list): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class L(list): pass
with self.assertRaises(TypeError):
(3,) + L([1,2])
def test_equal_operator_modifying_operand(self):
# test fix for seg fault reported in bpo-38588 part 2.
- class X:
@ -164,7 +164,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __eq__(self, other):
- list3.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __eq__(self,other) :
+ list2.clear()
@ -179,29 +179,29 @@ index 23ef902aa0b..b9afb1ef26e 100644
+ def __eq__(self, other):
+ list3.clear()
+ return NotImplemented
list1 = [X()]
list2 = [Y()]
@@ -250,24 +310,26 @@ class ListTest(list_tests.CommonTest):
self.assertFalse(list3 == list4)
def test_lt_operator_modifying_operand(self):
- # See gh-120298
- class evil:
- def __lt__(self, other):
- other.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # See gh-120298
+ class evil:
+ def __lt__(self, other):
+ other.clear()
+ return NotImplemented
a = [[evil()]]
with self.assertRaises(TypeError):
a[0] < a
def test_list_index_modifing_operand(self):
- # See gh-120384
- class evil:
@ -210,7 +210,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __iter__(self):
- yield from self.lst
- self.lst.clear()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # See gh-120384
+ class evil:
+ def __init__(self, lst):
@ -218,7 +218,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
+ def __iter__(self):
+ yield from self.lst
+ self.lst.clear()
lst = list(range(5))
operand = evil(lst)
@@ -286,19 +348,21 @@ class ListTest(list_tests.CommonTest):
@ -229,39 +229,39 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __eq__(self, other):
- lst.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __eq__(self, other):
+ lst.clear()
+ return NotImplemented
lst = [X()]
with self.assertRaises(ValueError):
lst.index(lst)
- class L(list):
- def __eq__(self, other):
- str(other)
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class L(list):
+ def __eq__(self, other):
+ str(other)
+ return NotImplemented
lst = L([X()])
lst.count(lst)
@@ -324,6 +388,7 @@ class ListTest(list_tests.CommonTest):
a.append(4)
self.assertEqual(list(it), [])
+ @unittest.skip("Fails on python <=3.13.2 and passes on >=3.13.3")
def test_deopt_from_append_list(self):
# gh-132011: it used to crash, because
# of `CALL_LIST_APPEND` specialization failure.
@@ -345,4 +410,4 @@ class ListTest(list_tests.CommonTest):
self.assertEqual(rc, 0)
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -101,7 +101,7 @@ class ListTest(list_tests.CommonTest):
list(sequence=[])
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(list):
pass
u = subclass([1, 2])
@ -110,7 +110,7 @@ class ListTest(list_tests.CommonTest):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(list):
def __init__(self, seq, newarg=None):
super().__init__(seq)
@ -120,7 +120,7 @@ class ListTest(list_tests.CommonTest):
self.assertEqual(list(u), [1, 2])
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(list):
def __new__(cls, seq, newarg=None):
self = super().__new__(cls, seq)
@ -172,7 +172,7 @@ class ListTest(list_tests.CommonTest):
lst *= size
def test_repr_mutate(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Obj:
@staticmethod
def __repr__():
@ -276,14 +276,14 @@ class ListTest(list_tests.CommonTest):
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
# optimization causes failures in code that relies on distinct
# function addresses.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class L(list): pass
with self.assertRaises(TypeError):
(3,) + L([1,2])
def test_equal_operator_modifying_operand(self):
# test fix for seg fault reported in bpo-38588 part 2.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __eq__(self,other) :
list2.clear()
@ -308,7 +308,7 @@ class ListTest(list_tests.CommonTest):
self.assertFalse(list3 == list4)
def test_lt_operator_modifying_operand(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# See gh-120298
class evil:
def __lt__(self, other):
@ -320,7 +320,7 @@ class ListTest(list_tests.CommonTest):
a[0] < a
def test_list_index_modifing_operand(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# See gh-120384
class evil:
def __init__(self, lst):
@ -346,7 +346,7 @@ class ListTest(list_tests.CommonTest):
# bpo-38610: The count(), index(), and remove() methods were not
# holding strong references to list elements while calling
# PyObject_RichCompareBool().
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __eq__(self, other):
lst.clear()
@ -356,7 +356,7 @@ class ListTest(list_tests.CommonTest):
with self.assertRaises(ValueError):
lst.index(lst)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class L(list):
def __eq__(self, other):
str(other)

View File

@ -63,20 +63,20 @@ index 5ee3055c871..5402cdc4a6c 100644
+
# Python test set -- math module
# XXXX Should not do tests around zero only
@@ -242,7 +300,7 @@ class BadDescr:
def __get__(self, obj, objtype=None):
raise ValueError
-class MathTests(unittest.TestCase):
+class MathTests(__TestCase):
def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0):
"""Compare arguments expected and got, as floats, if either
@@ -417,16 +475,17 @@ class MathTests(unittest.TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.ceil(NAN)))
- class TestCeil:
- def __ceil__(self):
- return 42
@ -87,7 +87,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- pass
- class TestBadCeil:
- __ceil__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestCeil:
+ def __ceil__(self):
+ return 42
@ -104,7 +104,7 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -533,6 +592,7 @@ class MathTests(unittest.TestCase):
self.ftest('fabs(0)', math.fabs(0), 0)
self.ftest('fabs(1)', math.fabs(1), 1)
+ @skipIfTorchDynamo("infinite loop")
def testFactorial(self):
self.assertEqual(math.factorial(0), 1)
@ -112,7 +112,7 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -573,16 +633,17 @@ class MathTests(unittest.TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.floor(NAN)))
- class TestFloor:
- def __floor__(self):
- return 42
@ -123,7 +123,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- pass
- class TestBadFloor:
- __floor__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestFloor:
+ def __floor__(self):
+ return 42
@ -139,32 +139,32 @@ index 5ee3055c871..5402cdc4a6c 100644
self.assertEqual(math.floor(FloatLike(41.9)), 41)
@@ -995,8 +1056,9 @@ class MathTests(unittest.TestCase):
)
# Verify tuple subclasses are allowed
- class T(tuple):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(tuple):
+ pass
self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0)
# Test handling of bad arguments
@@ -1028,8 +1090,9 @@ class MathTests(unittest.TestCase):
with self.assertRaises(TypeError):
dist([1], 2)
- class BadFloat:
- __float__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadFloat:
+ __float__ = BadDescr()
with self.assertRaises(ValueError):
dist([1], [BadFloat()])
@@ -1072,6 +1135,7 @@ class MathTests(unittest.TestCase):
with self.assertRaises(ValueError):
math.dist([1, 2], [3, 4, 5])
+ @slowTest
def testIsqrt(self):
# Test a variety of inputs, large and small.
@ -172,26 +172,26 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -1101,12 +1165,13 @@ class MathTests(unittest.TestCase):
self.assertIs(type(s), int)
self.assertEqual(s, 0)
- class IntegerLike(object):
- def __init__(self, value):
- self.value = value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IntegerLike(object):
+ def __init__(self, value):
+ self.value = value
- def __index__(self):
- return self.value
+ def __index__(self):
+ return self.value
s = math.isqrt(IntegerLike(1729))
self.assertIs(type(s), int)
@@ -1202,12 +1267,6 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.ldexp(NINF, n), NINF)
self.assertTrue(math.isnan(math.ldexp(NAN, n)))
- @requires_IEEE_754
- def testLdexp_denormal(self):
- # Denormal output incorrectly rounded (truncated)
@ -204,7 +204,7 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -1233,6 +1292,7 @@ class MathTests(unittest.TestCase):
self.assertRaises(ValueError, math.log1p, -1)
self.assertEqual(math.log1p(INF), INF)
+ @skipIfTorchDynamo("Infinite loop")
@requires_IEEE_754
def testLog2(self):
@ -212,7 +212,7 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -1251,6 +1311,7 @@ class MathTests(unittest.TestCase):
self.assertRaises(ValueError, math.log2, NINF)
self.assertTrue(math.isnan(math.log2(NAN)))
+ @skipIfTorchDynamo("Infinite loop")
@requires_IEEE_754
# log2() is not accurate enough on Mac OS X Tiger (10.4)
@ -220,20 +220,20 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -1332,17 +1393,18 @@ class MathTests(unittest.TestCase):
with self.assertRaises(RuntimeError):
sumprod(raise_after(5), range(10))
- from test.test_iter import BasicIterClass
+ from test_iter import BasicIterClass
self.assertEqual(sumprod(BasicIterClass(1), [1]), 0)
self.assertEqual(sumprod([1], BasicIterClass(1)), 0)
# Error in multiplication
- class BadMultiply:
- def __mul__(self, other):
- raise RuntimeError
- def __rmul__(self, other):
- raise RuntimeError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadMultiply:
+ def __mul__(self, other):
+ raise RuntimeError
@ -245,7 +245,7 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -1387,25 +1449,26 @@ class MathTests(unittest.TestCase):
Decimal = decimal.Decimal
Fraction = fractions.Fraction
- class Int(int):
- def __add__(self, other):
- return Int(int(self) + int(other))
@ -265,7 +265,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- __rmul__ = __mul__
- def __repr__(self):
- return f'Flt({int(self)})'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Int(int):
+ def __add__(self, other):
+ return Int(int(self) + int(other))
@ -285,13 +285,13 @@ index 5ee3055c871..5402cdc4a6c 100644
+ __rmul__ = __mul__
+ def __repr__(self):
+ return f'Flt({int(self)})'
def baseline_sumprod(p, q):
"""This defines the target behavior including exceptions and special values.
@@ -1925,16 +1988,17 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.trunc(-0.999999), -0)
self.assertEqual(math.trunc(-100.999), -100)
- class TestTrunc:
- def __trunc__(self):
- return 23
@ -302,7 +302,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- pass
- class TestBadTrunc:
- __trunc__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestTrunc:
+ def __trunc__(self):
+ return 23
@ -313,27 +313,27 @@ index 5ee3055c871..5402cdc4a6c 100644
+ pass
+ class TestBadTrunc:
+ __trunc__ = BadDescr()
self.assertEqual(math.trunc(TestTrunc()), 23)
self.assertEqual(math.trunc(FloatTrunc()), 23)
@@ -2167,9 +2231,10 @@ class MathTests(unittest.TestCase):
self.assertEqual(prod([1., F(3, 2)]), 1.5)
# Error in multiplication
- class BadMultiply:
- def __rmul__(self, other):
- raise RuntimeError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadMultiply:
+ def __rmul__(self, other):
+ raise RuntimeError
with self.assertRaises(RuntimeError):
prod([10., BadMultiply()])
@@ -2252,6 +2317,7 @@ class MathTests(unittest.TestCase):
self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
decimal.Decimal)
+ @skipIfTorchDynamo("Infinite loop")
def testPerm(self):
perm = math.perm
@ -341,15 +341,15 @@ index 5ee3055c871..5402cdc4a6c 100644
@@ -2316,6 +2382,7 @@ class MathTests(unittest.TestCase):
self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int)
self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int)
+ @skipIfTorchDynamo("infinite loop")
def testComb(self):
comb = math.comb
factorial = math.factorial
@@ -2446,6 +2513,7 @@ class MathTests(unittest.TestCase):
math.nextafter(1.0, INF, steps=-1)
+ @unittest.skip("flaky test under torch dynamo") # works on pytest and crashes on unittest
@requires_IEEE_754
def test_ulp(self):
@ -362,7 +362,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- def __float__(self):
- self.converted = True
- 1/0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class F:
+ def __float__(self):
+ self.converted = True
@ -372,21 +372,21 @@ index 5ee3055c871..5402cdc4a6c 100644
with self.assertRaises(TypeError):
@@ -2508,7 +2577,7 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y))
-class IsCloseTests(unittest.TestCase):
+class IsCloseTests(__TestCase):
isclose = math.isclose # subclasses should override this
def assertIsClose(self, a, b, *args, **kwargs):
@@ -2631,7 +2700,7 @@ class IsCloseTests(unittest.TestCase):
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
-class FMATests(unittest.TestCase):
+class FMATests(__TestCase):
""" Tests for math.fma. """
def test_fma_nan_results(self):
@@ -2719,8 +2788,7 @@ class FMATests(unittest.TestCase):
# properly: it doesn't use the right sign when the result is zero.
@ -400,8 +400,8 @@ index 5ee3055c871..5402cdc4a6c 100644
nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]
@@ -2879,10 +2947,5 @@ class FMATests(unittest.TestCase):
)
-def load_tests(loader, tests, pattern):
- from doctest import DocFileSuite
- tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt")))

View File

@ -475,7 +475,7 @@ class MathTests(__TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.ceil(NAN)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestCeil:
def __ceil__(self):
return 42
@ -633,7 +633,7 @@ class MathTests(__TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.floor(NAN)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestFloor:
def __floor__(self):
return 42
@ -1056,7 +1056,7 @@ class MathTests(__TestCase):
)
# Verify tuple subclasses are allowed
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(tuple):
pass
self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0)
@ -1090,7 +1090,7 @@ class MathTests(__TestCase):
with self.assertRaises(TypeError):
dist([1], 2)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadFloat:
__float__ = BadDescr()
@ -1165,7 +1165,7 @@ class MathTests(__TestCase):
self.assertIs(type(s), int)
self.assertEqual(s, 0)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IntegerLike(object):
def __init__(self, value):
self.value = value
@ -1399,7 +1399,7 @@ class MathTests(__TestCase):
self.assertEqual(sumprod([1], BasicIterClass(1)), 0)
# Error in multiplication
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadMultiply:
def __mul__(self, other):
raise RuntimeError
@ -1449,7 +1449,7 @@ class MathTests(__TestCase):
Decimal = decimal.Decimal
Fraction = fractions.Fraction
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Int(int):
def __add__(self, other):
return Int(int(self) + int(other))
@ -1988,7 +1988,7 @@ class MathTests(__TestCase):
self.assertEqual(math.trunc(-0.999999), -0)
self.assertEqual(math.trunc(-100.999), -100)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestTrunc:
def __trunc__(self):
return 23
@ -2231,7 +2231,7 @@ class MathTests(__TestCase):
self.assertEqual(prod([1., F(3, 2)]), 1.5)
# Error in multiplication
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadMultiply:
def __rmul__(self, other):
raise RuntimeError
@ -2540,7 +2540,7 @@ class MathTests(__TestCase):
def test_issue39871(self):
# A SystemError should not be raised if the first arg to atan2(),
# copysign(), or remainder() cannot be converted to a float.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F:
def __float__(self):
self.converted = True

View File

@ -27,13 +27,13 @@ index d90f820052c..5d9fdfb70a4 100644
import inspect
import pickle
@@ -84,9 +104,10 @@ class OperatorTestCase:
def test_eq(self):
operator = self.module
- class C(object):
- def __eq__(self, other):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __eq__(self, other):
+ raise SyntaxError
@ -41,13 +41,13 @@ index d90f820052c..5d9fdfb70a4 100644
self.assertRaises(SyntaxError, operator.eq, C(), C())
self.assertFalse(operator.eq(1, 0))
@@ -98,9 +119,10 @@ class OperatorTestCase:
def test_ne(self):
operator = self.module
- class C(object):
- def __ne__(self, other):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __ne__(self, other):
+ raise SyntaxError
@ -61,21 +61,21 @@ index d90f820052c..5d9fdfb70a4 100644
- class M:
- def __matmul__(self, other):
- return other - 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class M:
+ def __matmul__(self, other):
+ return other - 1
self.assertEqual(M() @ 42, 41)
def test_neg(self):
@@ -315,9 +338,10 @@ class OperatorTestCase:
def test_truth(self):
operator = self.module
- class C(object):
- def __bool__(self):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __bool__(self):
+ raise SyntaxError
@ -83,12 +83,12 @@ index d90f820052c..5d9fdfb70a4 100644
self.assertRaises(SyntaxError, operator.truth, C())
self.assertTrue(operator.truth(5))
@@ -349,8 +373,9 @@ class OperatorTestCase:
def test_attrgetter(self):
operator = self.module
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
a = A()
@ -97,39 +97,39 @@ index d90f820052c..5d9fdfb70a4 100644
@@ -371,9 +396,10 @@ class OperatorTestCase:
self.assertEqual(operator.attrgetter('x','z','y')(record), ('X', 'Z', 'Y'))
self.assertRaises(TypeError, operator.attrgetter, ('x', (), 'y'))
- class C(object):
- def __getattr__(self, name):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __getattr__(self, name):
+ raise SyntaxError
self.assertRaises(SyntaxError, operator.attrgetter('foo'), C())
# recursive gets
@@ -411,9 +437,10 @@ class OperatorTestCase:
f = operator.itemgetter(10)
self.assertRaises(IndexError, f, a)
- class C(object):
- def __getitem__(self, name):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __getitem__(self, name):
+ raise SyntaxError
self.assertRaises(SyntaxError, operator.itemgetter(42), C())
f = operator.itemgetter('name')
@@ -444,9 +471,10 @@ class OperatorTestCase:
self.assertEqual(operator.itemgetter(slice(2, 4))(t), ('c', 'd'))
# interesting sequences
- class T(tuple):
- 'Tuple subclass'
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(tuple):
+ 'Tuple subclass'
+ pass
@ -147,7 +147,7 @@ index d90f820052c..5d9fdfb70a4 100644
- return f
- def baz(*args, **kwds):
- return kwds['name'], kwds['self']
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def foo(self, *args, **kwds):
+ return args[0] + args[1]
@ -159,7 +159,7 @@ index d90f820052c..5d9fdfb70a4 100644
f = operator.methodcaller('foo')
self.assertRaises(IndexError, f, a)
@@ -480,21 +509,22 @@ class OperatorTestCase:
def test_inplace(self):
operator = self.module
- class C(object):
@ -177,7 +177,7 @@ index d90f820052c..5d9fdfb70a4 100644
- def __itruediv__ (self, other): return "itruediv"
- def __ixor__ (self, other): return "ixor"
- def __getitem__(self, other): return 5 # so that C is a sequence
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __iadd__ (self, other): return "iadd"
+ def __iand__ (self, other): return "iand"
@ -197,27 +197,27 @@ index d90f820052c..5d9fdfb70a4 100644
self.assertEqual(operator.iadd (c, 5), "iadd")
self.assertEqual(operator.iand (c, 5), "iand")
@@ -520,9 +550,10 @@ class OperatorTestCase:
def test_index(self):
operator = self.module
- class X:
- def __index__(self):
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __index__(self):
+ return 1
self.assertEqual(operator.index(X()), 1)
self.assertEqual(operator.index(0), 0)
@@ -539,9 +570,10 @@ class OperatorTestCase:
def test_not_(self):
operator = self.module
- class C:
- def __bool__(self):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __bool__(self):
+ raise SyntaxError
@ -225,17 +225,17 @@ index d90f820052c..5d9fdfb70a4 100644
self.assertRaises(SyntaxError, operator.not_, C())
self.assertFalse(operator.not_(5))
@@ -551,15 +583,16 @@ class OperatorTestCase:
def test_length_hint(self):
operator = self.module
- class X(object):
- def __init__(self, value):
- self.value = value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X(object):
+ def __init__(self, value):
+ self.value = value
- def __length_hint__(self):
- if type(self.value) is type:
- raise self.value
@ -246,47 +246,47 @@ index d90f820052c..5d9fdfb70a4 100644
+ raise self.value
+ else:
+ return self.value
self.assertEqual(operator.length_hint([], 2), 0)
self.assertEqual(operator.length_hint(iter([1, 2, 3])), 3)
@@ -574,7 +607,8 @@ class OperatorTestCase:
with self.assertRaises(LookupError):
operator.length_hint(X(LookupError))
- class Y: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Y: pass
msg = "'str' object cannot be interpreted as an integer"
with self.assertRaisesRegex(TypeError, msg):
@@ -628,11 +662,11 @@ class OperatorTestCase:
self.assertEqual(str(sig), '(obj, /)')
-class PyOperatorTestCase(OperatorTestCase, unittest.TestCase):
+class PyOperatorTestCase(OperatorTestCase, __TestCase):
module = py_operator
@unittest.skipUnless(c_operator, 'requires _operator')
-class COperatorTestCase(OperatorTestCase, unittest.TestCase):
+class COperatorTestCase(OperatorTestCase, __TestCase):
module = c_operator
@@ -645,8 +679,9 @@ class OperatorPickleTestCase:
def test_attrgetter(self):
attrgetter = self.module.attrgetter
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
a = A()
a.x = 'X'
a.y = 'Y'
@@ -688,13 +723,14 @@ class OperatorPickleTestCase:
def test_methodcaller(self):
methodcaller = self.module.methodcaller
- class A:
@ -296,7 +296,7 @@ index d90f820052c..5d9fdfb70a4 100644
- return f
- def baz(*args, **kwds):
- return kwds['name'], kwds['self']
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def foo(self, *args, **kwds):
+ return args[0] + args[1]
@ -310,31 +310,31 @@ index d90f820052c..5d9fdfb70a4 100644
@@ -717,25 +753,25 @@ class OperatorPickleTestCase:
# Can't test repr consistently with multiple keyword args
self.assertEqual(f2(a), f(a))
-class PyPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+class PyPyOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
module = py_operator
module2 = py_operator
@unittest.skipUnless(c_operator, 'requires _operator')
-class PyCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+class PyCOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
module = py_operator
module2 = c_operator
@unittest.skipUnless(c_operator, 'requires _operator')
-class CPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+class CPyOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
module = c_operator
module2 = py_operator
@unittest.skipUnless(c_operator, 'requires _operator')
-class CCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+class CCOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
module = c_operator
module2 = c_operator
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -104,7 +104,7 @@ class OperatorTestCase:
def test_eq(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __eq__(self, other):
raise SyntaxError
@ -119,7 +119,7 @@ class OperatorTestCase:
def test_ne(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __ne__(self, other):
raise SyntaxError
@ -267,7 +267,7 @@ class OperatorTestCase:
operator = self.module
self.assertRaises(TypeError, operator.matmul)
self.assertRaises(TypeError, operator.matmul, 42, 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class M:
def __matmul__(self, other):
return other - 1
@ -338,7 +338,7 @@ class OperatorTestCase:
def test_truth(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __bool__(self):
raise SyntaxError
@ -373,7 +373,7 @@ class OperatorTestCase:
def test_attrgetter(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
a = A()
@ -396,7 +396,7 @@ class OperatorTestCase:
self.assertEqual(operator.attrgetter('x','z','y')(record), ('X', 'Z', 'Y'))
self.assertRaises(TypeError, operator.attrgetter, ('x', (), 'y'))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __getattr__(self, name):
raise SyntaxError
@ -437,7 +437,7 @@ class OperatorTestCase:
f = operator.itemgetter(10)
self.assertRaises(IndexError, f, a)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __getitem__(self, name):
raise SyntaxError
@ -471,7 +471,7 @@ class OperatorTestCase:
self.assertEqual(operator.itemgetter(slice(2, 4))(t), ('c', 'd'))
# interesting sequences
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(tuple):
'Tuple subclass'
pass
@ -483,7 +483,7 @@ class OperatorTestCase:
operator = self.module
self.assertRaises(TypeError, operator.methodcaller)
self.assertRaises(TypeError, operator.methodcaller, 12)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def foo(self, *args, **kwds):
return args[0] + args[1]
@ -509,7 +509,7 @@ class OperatorTestCase:
def test_inplace(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __iadd__ (self, other): return "iadd"
def __iand__ (self, other): return "iand"
@ -550,7 +550,7 @@ class OperatorTestCase:
def test_index(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __index__(self):
return 1
@ -570,7 +570,7 @@ class OperatorTestCase:
def test_not_(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __bool__(self):
raise SyntaxError
@ -583,7 +583,7 @@ class OperatorTestCase:
def test_length_hint(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(object):
def __init__(self, value):
self.value = value
@ -607,7 +607,7 @@ class OperatorTestCase:
with self.assertRaises(LookupError):
operator.length_hint(X(LookupError))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Y: pass
msg = "'str' object cannot be interpreted as an integer"
@ -679,7 +679,7 @@ class OperatorPickleTestCase:
def test_attrgetter(self):
attrgetter = self.module.attrgetter
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
a = A()
@ -723,7 +723,7 @@ class OperatorPickleTestCase:
def test_methodcaller(self):
methodcaller = self.module.methodcaller
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def foo(self, *args, **kwds):
return args[0] + args[1]

View File

@ -64,7 +64,7 @@ index a9b6a84996e..efc4288d1a4 100644
import contextlib
import copy
@@ -113,13 +170,14 @@ class OrderedDictTests:
def test_init_calls(self):
calls = []
- class Spam:
@ -74,7 +74,7 @@ index a9b6a84996e..efc4288d1a4 100644
- def items(self):
- calls.append('items')
- return ()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Spam:
+ def keys(self):
+ calls.append('keys')
@ -82,7 +82,7 @@ index a9b6a84996e..efc4288d1a4 100644
+ def items(self):
+ calls.append('items')
+ return ()
self.OrderedDict(Spam())
self.assertEqual(calls, ['keys'])
@@ -129,9 +187,10 @@ class OrderedDictTests:
@ -92,21 +92,21 @@ index a9b6a84996e..efc4288d1a4 100644
- class ODNI(OrderedDict):
- def __init__(*args, **kwargs):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ODNI(OrderedDict):
+ def __init__(*args, **kwargs):
+ pass
od = ODNI()
od['a'] = 1 # This used to fail because __init__ was bypassed
@@ -267,9 +326,10 @@ class OrderedDictTests:
self.assertEqual(od.pop(k, 12345), 12345)
# make sure pop still works when __missing__ is defined
- class Missing(OrderedDict):
- def __missing__(self, key):
- return 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Missing(OrderedDict):
+ def __missing__(self, key):
+ return 0
@ -115,17 +115,17 @@ index a9b6a84996e..efc4288d1a4 100644
self.assertEqual(m.pop('a', 6), 1)
@@ -416,9 +476,10 @@ class OrderedDictTests:
self.assertEqual(od.setdefault('g', default=9), 9)
# make sure setdefault still works when __missing__ is defined
- class Missing(OrderedDict):
- def __missing__(self, key):
- return 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Missing(OrderedDict):
+ def __missing__(self, key):
+ return 0
self.assertEqual(Missing().setdefault(5, 9), 9)
def test_reinsert(self):
@@ -484,9 +545,10 @@ class OrderedDictTests:
def test_override_update(self):
@ -134,13 +134,13 @@ index a9b6a84996e..efc4288d1a4 100644
- class MyOD(OrderedDict):
- def update(self, *args, **kwds):
- raise Exception()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyOD(OrderedDict):
+ def update(self, *args, **kwds):
+ raise Exception()
items = [('a', 1), ('c', 3), ('b', 2)]
self.assertEqual(list(MyOD(items).items()), items)
@@ -507,9 +569,10 @@ class OrderedDictTests:
# should not crash Python.
OrderedDict = self.OrderedDict
@ -148,7 +148,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class MyOD(OrderedDict):
- def __del__(self):
- deleted.append(self.i)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyOD(OrderedDict):
+ def __del__(self):
+ deleted.append(self.i)
@ -158,7 +158,7 @@ index a9b6a84996e..efc4288d1a4 100644
@@ -521,19 +584,20 @@ class OrderedDictTests:
def test_delitem_hash_collision(self):
OrderedDict = self.OrderedDict
- class Key:
- def __init__(self, hash):
- self._hash = hash
@ -172,7 +172,7 @@ index a9b6a84996e..efc4288d1a4 100644
- return False
- def __repr__(self):
- return self.value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key:
+ def __init__(self, hash):
+ self._hash = hash
@ -186,149 +186,149 @@ index a9b6a84996e..efc4288d1a4 100644
+ return False
+ def __repr__(self):
+ return self.value
def blocking_hash(hash):
# See the collision-handling in lookdict (in Objects/dictobject.c).
@@ -560,9 +624,10 @@ class OrderedDictTests:
def test_issue24347(self):
OrderedDict = self.OrderedDict
- class Key:
- def __hash__(self):
- return randrange(100000)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key:
+ def __hash__(self):
+ return randrange(100000)
od = OrderedDict()
for i in range(100):
@@ -582,9 +647,10 @@ class OrderedDictTests:
def test_issue24348(self):
OrderedDict = self.OrderedDict
- class Key:
- def __hash__(self):
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key:
+ def __hash__(self):
+ return 1
od = OrderedDict()
od[Key()] = 0
@@ -760,15 +826,16 @@ class _TriggerSideEffectOnEqual:
def side_effect(self):
raise NotImplementedError
-class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
+class PurePythonOrderedDictTests(OrderedDictTests, __TestCase):
module = py_coll
OrderedDict = py_coll.OrderedDict
def test_issue119004_attribute_error(self):
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- del dict1[TODEL]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ del dict1[TODEL]
TODEL = Key()
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
@@ -781,7 +848,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
-class CPythonBuiltinDictTests(unittest.TestCase):
+class CPythonBuiltinDictTests(__TestCase):
"""Builtin dict preserves insertion order.
Reuse some of tests in OrderedDict selectively.
@@ -800,6 +867,7 @@ for method in (
del method
+
class CPythonOrderedDictSideEffects:
def check_runtime_error_issue119004(self, dict1, dict2):
@@ -807,9 +875,10 @@ class CPythonOrderedDictSideEffects:
self.assertRaisesRegex(RuntimeError, msg, operator.eq, dict1, dict2)
def test_issue119004_change_size_by_clear(self):
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- dict1.clear()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ dict1.clear()
dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
@@ -819,9 +888,10 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_size_by_delete_key(self):
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- del dict1[TODEL]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ del dict1[TODEL]
TODEL = Key()
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
@@ -832,10 +902,11 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_linked_list_by_clear(self):
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- dict1.clear()
- dict1['a'] = dict1['b'] = 'c'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ dict1.clear()
+ dict1['a'] = dict1['b'] = 'c'
dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
@@ -845,10 +916,11 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_linked_list_by_delete_key(self):
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- del dict1[TODEL]
- dict1['a'] = 'c'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ del dict1[TODEL]
+ dict1['a'] = 'c'
TODEL = Key()
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
@@ -859,10 +931,11 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_size_by_delete_key_in_dict_eq(self):
- class Key(_TriggerSideEffectOnEqual):
- trigger = 0
- def side_effect(self):
- del dict1[TODEL]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ trigger = 0
+ def side_effect(self):
+ del dict1[TODEL]
TODEL = Key()
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
@@ -878,7 +951,7 @@ class CPythonOrderedDictSideEffects:
@ -337,25 +337,25 @@ index a9b6a84996e..efc4288d1a4 100644
CPythonOrderedDictSideEffects,
- unittest.TestCase):
+ __TestCase):
module = c_coll
OrderedDict = c_coll.OrderedDict
@@ -986,7 +1059,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
pass
-class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
+class PurePythonOrderedDictWithSlotsCopyingTests(__TestCase):
module = py_coll
class OrderedDict(py_coll.OrderedDict):
@@ -995,7 +1068,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
-class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
+class CPythonOrderedDictWithSlotsCopyingTests(__TestCase):
module = c_coll
class OrderedDict(c_coll.OrderedDict):
@@ -1008,6 +1081,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
@ -363,7 +363,7 @@ index a9b6a84996e..efc4288d1a4 100644
def setUpClass(cls):
cls.type2test = py_coll.OrderedDict
+ super().setUpClass()
def test_popitem(self):
d = self._empty_mapping()
@@ -1020,6 +1094,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
@ -371,7 +371,7 @@ index a9b6a84996e..efc4288d1a4 100644
def setUpClass(cls):
cls.type2test = c_coll.OrderedDict
+ super().setUpClass()
def test_popitem(self):
d = self._empty_mapping()
@@ -1033,6 +1108,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
@ -379,7 +379,7 @@ index a9b6a84996e..efc4288d1a4 100644
pass
cls.type2test = MyOrderedDict
+ super().setUpClass()
def test_popitem(self):
d = self._empty_mapping()
@@ -1047,6 +1123,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
@ -387,32 +387,32 @@ index a9b6a84996e..efc4288d1a4 100644
pass
cls.type2test = MyOrderedDict
+ super().setUpClass()
def test_popitem(self):
d = self._empty_mapping()
@@ -1120,21 +1197,22 @@ class SimpleLRUCacheTests:
self.assertEqual(list(c), [1, 3, 2])
-class PySimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase):
+class PySimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase):
class type2test(SimpleLRUCache, py_coll.OrderedDict):
pass
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
-class CSimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase):
+class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase):
@classmethod
def setUpClass(cls):
class type2test(SimpleLRUCache, c_coll.OrderedDict):
pass
cls.type2test = type2test
+ super().setUpClass()
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -170,7 +170,7 @@ class OrderedDictTests:
def test_init_calls(self):
calls = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Spam:
def keys(self):
calls.append('keys')
@ -187,7 +187,7 @@ class OrderedDictTests:
# a consistent internal state is created in __new__
# rather than __init__.
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ODNI(OrderedDict):
def __init__(*args, **kwargs):
pass
@ -326,7 +326,7 @@ class OrderedDictTests:
self.assertEqual(od.pop(k, 12345), 12345)
# make sure pop still works when __missing__ is defined
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Missing(OrderedDict):
def __missing__(self, key):
return 0
@ -476,7 +476,7 @@ class OrderedDictTests:
self.assertEqual(od.setdefault('g', default=9), 9)
# make sure setdefault still works when __missing__ is defined
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Missing(OrderedDict):
def __missing__(self, key):
return 0
@ -545,7 +545,7 @@ class OrderedDictTests:
def test_override_update(self):
OrderedDict = self.OrderedDict
# Verify that subclasses can override update() without breaking __init__()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyOD(OrderedDict):
def update(self, *args, **kwds):
raise Exception()
@ -569,7 +569,7 @@ class OrderedDictTests:
# should not crash Python.
OrderedDict = self.OrderedDict
deleted = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyOD(OrderedDict):
def __del__(self):
deleted.append(self.i)
@ -584,7 +584,7 @@ class OrderedDictTests:
def test_delitem_hash_collision(self):
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key:
def __init__(self, hash):
self._hash = hash
@ -624,7 +624,7 @@ class OrderedDictTests:
def test_issue24347(self):
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key:
def __hash__(self):
return randrange(100000)
@ -647,7 +647,7 @@ class OrderedDictTests:
def test_issue24348(self):
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key:
def __hash__(self):
return 1
@ -832,7 +832,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, __TestCase):
OrderedDict = py_coll.OrderedDict
def test_issue119004_attribute_error(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
del dict1[TODEL]
@ -875,7 +875,7 @@ class CPythonOrderedDictSideEffects:
self.assertRaisesRegex(RuntimeError, msg, operator.eq, dict1, dict2)
def test_issue119004_change_size_by_clear(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
dict1.clear()
@ -888,7 +888,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_size_by_delete_key(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
del dict1[TODEL]
@ -902,7 +902,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_linked_list_by_clear(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
dict1.clear()
@ -916,7 +916,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_linked_list_by_delete_key(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
del dict1[TODEL]
@ -931,7 +931,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_size_by_delete_key_in_dict_eq(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
trigger = 0
def side_effect(self):

View File

@ -62,23 +62,23 @@ index d9102eb98a5..c8ee5ca451f 100644
@@ -38,7 +91,7 @@ class HashCountingInt(int):
self.hash_count += 1
return int.__hash__(self)
-class TestJointOps:
+class _TestJointOps:
# Tests common to both set and frozenset
def setUp(self):
@@ -47,6 +100,7 @@ class TestJointOps:
self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
self.s = self.thetype(word)
self.d = dict.fromkeys(word)
+ super().setUp()
def test_new_or_init(self):
self.assertRaises(TypeError, self.thetype, [], 2)
@@ -261,13 +315,14 @@ class TestJointOps:
self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
def test_deepcopy(self):
- class Tracer:
- def __init__(self, value):
@ -87,7 +87,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- return self.value
- def __deepcopy__(self, memo=None):
- return Tracer(self.value + 1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Tracer:
+ def __init__(self, value):
+ self.value = value
@ -99,25 +99,25 @@ index d9102eb98a5..c8ee5ca451f 100644
s = self.thetype([t])
dup = copy.deepcopy(s)
@@ -279,8 +334,9 @@ class TestJointOps:
def test_gc(self):
# Create a nest of cycles to exercise overall ref count check
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
s = set(A() for i in range(1000))
for elem in s:
elem.cycle = s
@@ -289,9 +345,10 @@ class TestJointOps:
def test_subclass_with_custom_hash(self):
# Bug #1257731
- class H(self.thetype):
- def __hash__(self):
- return int(id(self) & 0x7fffffff)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class H(self.thetype):
+ def __hash__(self):
+ return int(id(self) & 0x7fffffff)
@ -125,12 +125,12 @@ index d9102eb98a5..c8ee5ca451f 100644
f=set()
f.add(s)
@@ -342,8 +399,9 @@ class TestJointOps:
def test_container_iterator(self):
# Bug #3680: tp_traverse was not implemented for set iterator object
- class C(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ pass
obj = C()
@ -139,15 +139,15 @@ index d9102eb98a5..c8ee5ca451f 100644
@@ -355,7 +413,7 @@ class TestJointOps:
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, self.thetype)
-class TestSet(TestJointOps, unittest.TestCase):
+class TestSet(_TestJointOps, __TestCase):
thetype = set
basetype = set
@@ -600,19 +658,20 @@ class TestSet(TestJointOps, unittest.TestCase):
self.assertRaises(ReferenceError, str, p)
def test_rich_compare(self):
- class TestRichSetCompare:
- def __gt__(self, some_set):
@ -162,7 +162,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- def __le__(self, some_set):
- self.le_called = True
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestRichSetCompare:
+ def __gt__(self, some_set):
+ self.gt_called = True
@ -176,16 +176,16 @@ index d9102eb98a5..c8ee5ca451f 100644
+ def __le__(self, some_set):
+ self.le_called = True
+ return False
# This first tries the builtin rich set comparison, which doesn't know
# how to handle the custom object. Upon returning NotImplemented, the
@@ -644,28 +703,31 @@ class TestSetSubclass(TestSet):
basetype = set
def test_keywords_in_subclass(self):
- class subclass(set):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(set):
+ pass
u = subclass([1, 2])
@ -193,12 +193,12 @@ index d9102eb98a5..c8ee5ca451f 100644
self.assertEqual(set(u), {1, 2})
with self.assertRaises(TypeError):
subclass(sequence=())
- class subclass_with_init(set):
- def __init__(self, arg, newarg=None):
- super().__init__(arg)
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(set):
+ def __init__(self, arg, newarg=None):
+ super().__init__(arg)
@ -207,13 +207,13 @@ index d9102eb98a5..c8ee5ca451f 100644
self.assertIs(type(u), subclass_with_init)
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
- class subclass_with_new(set):
- def __new__(cls, arg, newarg=None):
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(set):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -224,20 +224,20 @@ index d9102eb98a5..c8ee5ca451f 100644
self.assertEqual(set(u), {1, 2})
@@ -675,7 +737,7 @@ class TestSetSubclass(TestSet):
subclass_with_new([1, 2], newarg=3)
-class TestFrozenSet(TestJointOps, unittest.TestCase):
+class TestFrozenSet(_TestJointOps, __TestCase):
thetype = frozenset
basetype = frozenset
@@ -756,27 +818,30 @@ class TestFrozenSetSubclass(TestFrozenSet):
basetype = frozenset
def test_keywords_in_subclass(self):
- class subclass(frozenset):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(frozenset):
+ pass
u = subclass([1, 2])
@ -245,11 +245,11 @@ index d9102eb98a5..c8ee5ca451f 100644
self.assertEqual(set(u), {1, 2})
with self.assertRaises(TypeError):
subclass(sequence=())
- class subclass_with_init(frozenset):
- def __init__(self, arg, newarg=None):
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(frozenset):
+ def __init__(self, arg, newarg=None):
+ self.newarg = newarg
@ -257,13 +257,13 @@ index d9102eb98a5..c8ee5ca451f 100644
self.assertIs(type(u), subclass_with_init)
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
- class subclass_with_new(frozenset):
- def __new__(cls, arg, newarg=None):
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(frozenset):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -275,7 +275,7 @@ index d9102eb98a5..c8ee5ca451f 100644
@@ -811,10 +876,17 @@ class TestFrozenSetSubclass(TestFrozenSet):
class SetSubclassWithSlots(set):
__slots__ = ('x', 'y', '__dict__')
-class TestSetSubclassWithSlots(unittest.TestCase):
+class TestSetSubclassWithSlots(__TestCase):
thetype = SetSubclassWithSlots
@ -290,22 +290,22 @@ index d9102eb98a5..c8ee5ca451f 100644
+ self.s = self.thetype(word)
+ self.d = dict.fromkeys(word)
+ super().setUp()
class FrozenSetSubclassWithSlots(frozenset):
__slots__ = ('x', 'y', '__dict__')
@@ -828,7 +900,7 @@ empty_set = set()
#==============================================================================
-class TestBasicOps:
+class _TestBasicOps:
def test_repr(self):
if self.repr is not None:
@@ -934,7 +1006,7 @@ class TestBasicOps:
#------------------------------------------------------------------------------
-class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
+class TestBasicOpsEmpty(_TestBasicOps, __TestCase):
def setUp(self):
@ -316,9 +316,9 @@ index d9102eb98a5..c8ee5ca451f 100644
self.length = 0
self.repr = "set()"
+ super().setUp()
#------------------------------------------------------------------------------
-class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
+class TestBasicOpsSingleton(_TestBasicOps, __TestCase):
def setUp(self):
@ -329,13 +329,13 @@ index d9102eb98a5..c8ee5ca451f 100644
self.length = 1
self.repr = "{3}"
+ super().setUp()
def test_in(self):
self.assertIn(3, self.set)
@@ -962,7 +1036,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
+class TestBasicOpsTuple(_TestBasicOps, __TestCase):
def setUp(self):
@ -346,13 +346,13 @@ index d9102eb98a5..c8ee5ca451f 100644
self.length = 1
self.repr = "{(0, 'zero')}"
+ super().setUp()
def test_in(self):
self.assertIn((0, "zero"), self.set)
@@ -979,7 +1054,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
+class TestBasicOpsTriple(_TestBasicOps, __TestCase):
def setUp(self):
@ -363,9 +363,9 @@ index d9102eb98a5..c8ee5ca451f 100644
self.length = 3
self.repr = None
+ super().setUp()
#------------------------------------------------------------------------------
-class TestBasicOpsString(TestBasicOps, unittest.TestCase):
+class TestBasicOpsString(_TestBasicOps, __TestCase):
def setUp(self):
@ -375,12 +375,12 @@ index d9102eb98a5..c8ee5ca451f 100644
self.dup = set(self.values)
self.length = 3
+ super().setUp()
def test_repr(self):
self.check_repr_against_values()
#------------------------------------------------------------------------------
-class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
+class TestBasicOpsBytes(_TestBasicOps, __TestCase):
def setUp(self):
@ -390,12 +390,12 @@ index d9102eb98a5..c8ee5ca451f 100644
self.dup = set(self.values)
self.length = 3
+ super().setUp()
def test_repr(self):
self.check_repr_against_values()
#------------------------------------------------------------------------------
-class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
+class TestBasicOpsMixedStringBytes(_TestBasicOps, __TestCase):
def setUp(self):
@ -406,71 +406,71 @@ index d9102eb98a5..c8ee5ca451f 100644
self.dup = set(self.values)
self.length = 4
+ super().setUp()
def test_repr(self):
self.check_repr_against_values()
@@ -1038,7 +1117,7 @@ def baditer():
def gooditer():
yield True
-class TestExceptionPropagation(unittest.TestCase):
+class TestExceptionPropagation(__TestCase):
"""SF 628246: Set constructor should not trap iterator TypeErrors"""
def test_instanceWithException(self):
@@ -1065,7 +1144,7 @@ class TestExceptionPropagation(unittest.TestCase):
#==============================================================================
-class TestSetOfSets(unittest.TestCase):
+class TestSetOfSets(__TestCase):
def test_constructor(self):
inner = frozenset([1])
outer = set([inner])
@@ -1078,9 +1157,10 @@ class TestSetOfSets(unittest.TestCase):
#==============================================================================
-class TestBinaryOps(unittest.TestCase):
+class TestBinaryOps(__TestCase):
def setUp(self):
self.set = set((2, 4, 6))
+ super().setUp()
def test_eq(self): # SF bug 643115
self.assertEqual(self.set, set({2:1,4:3,6:5}))
@@ -1151,9 +1231,10 @@ class TestBinaryOps(unittest.TestCase):
#==============================================================================
-class TestUpdateOps(unittest.TestCase):
+class TestUpdateOps(__TestCase):
def setUp(self):
self.set = set((2, 4, 6))
+ super().setUp()
def test_union_subset(self):
self.set |= set([2])
@@ -1237,10 +1318,11 @@ class TestUpdateOps(unittest.TestCase):
#==============================================================================
-class TestMutate(unittest.TestCase):
+class TestMutate(__TestCase):
def setUp(self):
self.values = ["a", "b", "c"]
self.set = set(self.values)
+ super().setUp()
def test_add_present(self):
self.set.add("c")
@@ -1311,7 +1393,7 @@ class TestMutate(unittest.TestCase):
#==============================================================================
-class TestSubsets:
+class _TestSubsets:
case2method = {"<=": "issubset",
">=": "issuperset",
@@ -1334,22 +1416,22 @@ class TestSubsets:
@ -483,7 +483,7 @@ index d9102eb98a5..c8ee5ca451f 100644
+ method = getattr(x, _TestSubsets.case2method[case])
result = method(y)
self.assertEqual(result, expected)
# Now do the same for the operands reversed.
- rcase = TestSubsets.reverse[case]
+ rcase = _TestSubsets.reverse[case]
@ -496,61 +496,61 @@ index d9102eb98a5..c8ee5ca451f 100644
result = method(x)
self.assertEqual(result, expected)
#------------------------------------------------------------------------------
-class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
+class TestSubsetEqualEmpty(_TestSubsets, __TestCase):
left = set()
right = set()
name = "both empty"
@@ -1357,7 +1439,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
+class TestSubsetEqualNonEmpty(_TestSubsets, __TestCase):
left = set([1, 2])
right = set([1, 2])
name = "equal pair"
@@ -1365,7 +1447,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
+class TestSubsetEmptyNonEmpty(_TestSubsets, __TestCase):
left = set()
right = set([1, 2])
name = "one empty, one non-empty"
@@ -1373,7 +1455,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetPartial(TestSubsets, unittest.TestCase):
+class TestSubsetPartial(_TestSubsets, __TestCase):
left = set([1])
right = set([1, 2])
name = "one a non-empty proper subset of other"
@@ -1381,7 +1463,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
+class TestSubsetNonOverlap(_TestSubsets, __TestCase):
left = set([1])
right = set([2])
name = "neither empty, neither contains"
@@ -1389,7 +1471,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
#==============================================================================
-class TestOnlySetsInBinaryOps:
+class _TestOnlySetsInBinaryOps:
def test_eq_ne(self):
# Unlike the others, this is testing that == and != *are* allowed.
@@ -1505,47 +1587,52 @@ class TestOnlySetsInBinaryOps:
#------------------------------------------------------------------------------
-class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsNumeric(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
@ -558,9 +558,9 @@ index d9102eb98a5..c8ee5ca451f 100644
self.other = 19
self.otherIsIterable = False
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsDict(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
@ -568,9 +568,9 @@ index d9102eb98a5..c8ee5ca451f 100644
self.other = {1:2, 3:4}
self.otherIsIterable = True
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsOperator(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
@ -578,9 +578,9 @@ index d9102eb98a5..c8ee5ca451f 100644
self.other = operator.add
self.otherIsIterable = False
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsTuple(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
@ -588,9 +588,9 @@ index d9102eb98a5..c8ee5ca451f 100644
self.other = (2, 4, 6)
self.otherIsIterable = True
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsString(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
@ -598,9 +598,9 @@ index d9102eb98a5..c8ee5ca451f 100644
self.other = 'abc'
self.otherIsIterable = True
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsGenerator(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
@ -611,80 +611,80 @@ index d9102eb98a5..c8ee5ca451f 100644
self.other = gen()
self.otherIsIterable = True
+ super().setUp()
#==============================================================================
-class TestCopying:
+class _TestCopying:
def test_copy(self):
dup = self.set.copy()
@@ -1577,40 +1665,46 @@ class TestCopying:
#------------------------------------------------------------------------------
-class TestCopyingEmpty(TestCopying, unittest.TestCase):
+class TestCopyingEmpty(_TestCopying, __TestCase):
def setUp(self):
self.set = set()
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingSingleton(TestCopying, unittest.TestCase):
+class TestCopyingSingleton(_TestCopying, __TestCase):
def setUp(self):
self.set = set(["hello"])
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingTriple(TestCopying, unittest.TestCase):
+class TestCopyingTriple(_TestCopying, __TestCase):
def setUp(self):
self.set = set(["zero", 0, None])
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingTuple(TestCopying, unittest.TestCase):
+class TestCopyingTuple(_TestCopying, __TestCase):
def setUp(self):
self.set = set([(1, 2)])
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingNested(TestCopying, unittest.TestCase):
+class TestCopyingNested(_TestCopying, __TestCase):
def setUp(self):
self.set = set([((1, 2), (3, 4))])
+ super().setUp()
#==============================================================================
-class TestIdentities(unittest.TestCase):
+class TestIdentities(__TestCase):
def setUp(self):
self.a = set('abracadabra')
self.b = set('alacazam')
+ super().setUp()
def test_binopsVsSubsets(self):
a, b = self.a, self.b
@@ -1727,7 +1821,7 @@ def L(seqn):
'Test multiple tiers of iterators'
return chain(map(lambda x:x, R(Ig(G(seqn)))))
-class TestVariousIteratorArgs(unittest.TestCase):
+class TestVariousIteratorArgs(__TestCase):
def test_constructor(self):
for cons in (set, frozenset):
@@ -1785,7 +1879,7 @@ class bad_dict_clear:
def __hash__(self):
return 0
-class TestWeirdBugs(unittest.TestCase):
+class TestWeirdBugs(__TestCase):
def test_8420_set_merge(self):
@ -692,7 +692,7 @@ index d9102eb98a5..c8ee5ca451f 100644
global be_bad, set2, dict2
@@ -1813,12 +1907,13 @@ class TestWeirdBugs(unittest.TestCase):
list(si)
def test_merge_and_mutate(self):
- class X:
- def __hash__(self):
@ -700,27 +700,27 @@ index d9102eb98a5..c8ee5ca451f 100644
- def __eq__(self, o):
- other.clear()
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __hash__(self):
+ return hash(0)
+ def __eq__(self, o):
+ other.clear()
+ return False
other = set()
other = {X() for i in range(10)}
@@ -1826,24 +1921,25 @@ class TestWeirdBugs(unittest.TestCase):
s.update(other)
-class TestOperationsMutating:
+class _TestOperationsMutating:
"""Regression test for bpo-46615"""
constructor1 = None
constructor2 = None
def make_sets_of_bad_objects(self):
- class Bad:
- def __eq__(self, other):
@ -733,7 +733,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- return bool(randrange(2))
- def __hash__(self):
- return randrange(2)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Bad:
+ def __eq__(self, other):
+ if not enabled:
@ -750,89 +750,89 @@ index d9102eb98a5..c8ee5ca451f 100644
set1 = self.constructor1(Bad() for _ in range(randrange(50)))
@@ -1862,7 +1958,7 @@ class TestOperationsMutating:
self.assertIn("changed size during iteration", str(e))
-class TestBinaryOpsMutating(TestOperationsMutating):
+class _TestBinaryOpsMutating(_TestOperationsMutating):
def test_eq_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a == b)
@@ -1933,24 +2029,24 @@ class TestBinaryOpsMutating(TestOperationsMutating):
self.check_set_op_does_not_crash(f3)
-class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Set_Set(_TestBinaryOpsMutating, __TestCase):
constructor1 = set
constructor2 = set
-class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Subclass_Subclass(_TestBinaryOpsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = SetSubclass
-class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Set_Subclass(_TestBinaryOpsMutating, __TestCase):
constructor1 = set
constructor2 = SetSubclass
-class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Subclass_Set(_TestBinaryOpsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = set
-class TestMethodsMutating(TestOperationsMutating):
+class _TestMethodsMutating(_TestOperationsMutating):
def test_issubset_with_mutation(self):
self.check_set_op_does_not_crash(set.issubset)
@@ -1986,27 +2082,27 @@ class TestMethodsMutating(TestOperationsMutating):
self.check_set_op_does_not_crash(set.update)
-class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_Set(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = set
-class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Subclass_Subclass(_TestMethodsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = SetSubclass
-class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_Subclass(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = SetSubclass
-class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Subclass_Set(_TestMethodsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = set
-class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_Dict(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = dict.fromkeys
-class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_List(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = list
@@ -2068,7 +2164,7 @@ def faces(G):
return f
-class TestGraphs(unittest.TestCase):
+class TestGraphs(__TestCase):
def test_cube(self):
@@ -2118,4 +2214,4 @@ class TestGraphs(unittest.TestCase):
#==============================================================================
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -315,7 +315,7 @@ class _TestJointOps:
self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
def test_deepcopy(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Tracer:
def __init__(self, value):
self.value = value
@ -334,7 +334,7 @@ class _TestJointOps:
def test_gc(self):
# Create a nest of cycles to exercise overall ref count check
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
s = set(A() for i in range(1000))
@ -345,7 +345,7 @@ class _TestJointOps:
def test_subclass_with_custom_hash(self):
# Bug #1257731
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class H(self.thetype):
def __hash__(self):
return int(id(self) & 0x7fffffff)
@ -399,7 +399,7 @@ class _TestJointOps:
def test_container_iterator(self):
# Bug #3680: tp_traverse was not implemented for set iterator object
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
pass
obj = C()
@ -658,7 +658,7 @@ class TestSet(_TestJointOps, __TestCase):
self.assertRaises(ReferenceError, str, p)
def test_rich_compare(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestRichSetCompare:
def __gt__(self, some_set):
self.gt_called = True
@ -703,7 +703,7 @@ class TestSetSubclass(TestSet):
basetype = set
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(set):
pass
u = subclass([1, 2])
@ -712,7 +712,7 @@ class TestSetSubclass(TestSet):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(set):
def __init__(self, arg, newarg=None):
super().__init__(arg)
@ -722,7 +722,7 @@ class TestSetSubclass(TestSet):
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(set):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -818,7 +818,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
basetype = frozenset
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(frozenset):
pass
u = subclass([1, 2])
@ -827,7 +827,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(frozenset):
def __init__(self, arg, newarg=None):
self.newarg = newarg
@ -836,7 +836,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(frozenset):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -1907,7 +1907,7 @@ class TestWeirdBugs(__TestCase):
list(si)
def test_merge_and_mutate(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __hash__(self):
return hash(0)
@ -1928,7 +1928,7 @@ class _TestOperationsMutating:
constructor2 = None
def make_sets_of_bad_objects(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Bad:
def __eq__(self, other):
if not enabled:

View File

@ -63,7 +63,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
@@ -39,7 +93,7 @@ def check(tag, expected, raw, compare=None):
nerrors += 1
return
-class TestBase(unittest.TestCase):
+class TestBase(__TestCase):
def testStressfully(self):
@ -72,18 +72,18 @@ index 2a7cfb7affa..4805f1fcceb 100644
@@ -48,32 +102,33 @@ class TestBase(unittest.TestCase):
sizes.extend(range(n-1, n+2))
sizes.extend([10, 100, 1000])
- class Complains(object):
- maybe_complain = True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Complains(object):
+ maybe_complain = True
- def __init__(self, i):
- self.i = i
+ def __init__(self, i):
+ self.i = i
- def __lt__(self, other):
- if Complains.maybe_complain and random.random() < 0.001:
- if verbose:
@ -96,12 +96,12 @@ index 2a7cfb7affa..4805f1fcceb 100644
+ print(" complaining at", self, other)
+ raise RuntimeError
+ return self.i < other.i
- def __repr__(self):
- return "Complains(%d)" % self.i
+ def __repr__(self):
+ return "Complains(%d)" % self.i
- class Stable(object):
- def __init__(self, key, i):
- self.key = key
@ -110,31 +110,31 @@ index 2a7cfb7affa..4805f1fcceb 100644
+ def __init__(self, key, i):
+ self.key = key
+ self.index = i
- def __lt__(self, other):
- return self.key < other.key
+ def __lt__(self, other):
+ return self.key < other.key
- def __repr__(self):
- return "Stable(%d, %d)" % (self.key, self.index)
+ def __repr__(self):
+ return "Stable(%d, %d)" % (self.key, self.index)
for n in sizes:
x = list(range(n))
@@ -151,20 +206,21 @@ class TestBase(unittest.TestCase):
self.assertEqual(forced, native)
#==============================================================================
-class TestBugs(unittest.TestCase):
+class TestBugs(__TestCase):
def test_bug453523(self):
# bug 453523 -- list.sort() crasher.
# If this fails, the most likely outcome is a core dump.
# Mutations during a list sort should raise a ValueError.
- class C:
- def __lt__(self, other):
- if L and random.random() < 0.75:
@ -142,7 +142,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- else:
- L.append(3)
- return random.random() < 0.5
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __lt__(self, other):
+ if L and random.random() < 0.75:
@ -150,20 +150,20 @@ index 2a7cfb7affa..4805f1fcceb 100644
+ else:
+ L.append(3)
+ return random.random() < 0.5
L = [C() for i in range(50)]
self.assertRaises(ValueError, L.sort)
@@ -188,7 +244,7 @@ class TestBugs(unittest.TestCase):
#==============================================================================
-class TestDecorateSortUndecorate(unittest.TestCase):
+class TestDecorateSortUndecorate(__TestCase):
def test_decorated(self):
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
@@ -228,26 +284,28 @@ class TestDecorateSortUndecorate(unittest.TestCase):
def test_key_with_mutating_del(self):
data = list(range(10))
- class SortKiller(object):
@ -174,7 +174,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- data[:] = range(20)
- def __lt__(self, other):
- return id(self) < id(other)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SortKiller(object):
+ def __init__(self, x):
+ pass
@ -184,7 +184,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
+ def __lt__(self, other):
+ return id(self) < id(other)
self.assertRaises(ValueError, data.sort, key=SortKiller)
def test_key_with_mutating_del_and_exception(self):
data = list(range(10))
## dup = data[:]
@ -195,7 +195,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- def __del__(self):
- del data[:]
- data[:] = list(range(20))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SortKiller(object):
+ def __init__(self, x):
+ if x > 2:
@ -209,7 +209,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
@@ -309,7 +367,7 @@ def check_against_PyObject_RichCompareBool(self, L):
self.assertIs(opt, ref)
#note: not assertEqual! We want to ensure *identical* behavior.
-class TestOptimizedCompares(unittest.TestCase):
+class TestOptimizedCompares(__TestCase):
def test_safe_object_compare(self):
@ -218,39 +218,39 @@ index 2a7cfb7affa..4805f1fcceb 100644
@@ -331,17 +389,18 @@ class TestOptimizedCompares(unittest.TestCase):
# This test is by ppperry. It ensures that unsafe_object_compare is
# verifying ms->key_richcompare == tp->richcompare before comparing.
- class WackyComparator(int):
- def __lt__(self, other):
- elem.__class__ = WackyList2
- return int.__lt__(self, other)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class WackyComparator(int):
+ def __lt__(self, other):
+ elem.__class__ = WackyList2
+ return int.__lt__(self, other)
- class WackyList1(list):
- pass
+ class WackyList1(list):
+ pass
- class WackyList2(list):
- def __lt__(self, other):
- raise ValueError
+ class WackyList2(list):
+ def __lt__(self, other):
+ raise ValueError
L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
elem = L[-1]
@@ -355,9 +414,10 @@ class TestOptimizedCompares(unittest.TestCase):
# The following test is also by ppperry. It ensures that
# unsafe_object_compare handles Py_NotImplemented appropriately.
- class PointlessComparator:
- def __lt__(self, other):
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class PointlessComparator:
+ def __lt__(self, other):
+ return NotImplemented
@ -259,7 +259,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
self.assertRaises(TypeError, [(x,) for x in L].sort)
@@ -408,4 +468,4 @@ class TestOptimizedCompares(unittest.TestCase):
#==============================================================================
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -102,7 +102,7 @@ class TestBase(__TestCase):
sizes.extend(range(n-1, n+2))
sizes.extend([10, 100, 1000])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Complains(object):
maybe_complain = True
@ -213,7 +213,7 @@ class TestBugs(__TestCase):
# If this fails, the most likely outcome is a core dump.
# Mutations during a list sort should raise a ValueError.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __lt__(self, other):
if L and random.random() < 0.75:
@ -284,7 +284,7 @@ class TestDecorateSortUndecorate(__TestCase):
def test_key_with_mutating_del(self):
data = list(range(10))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SortKiller(object):
def __init__(self, x):
pass
@ -298,7 +298,7 @@ class TestDecorateSortUndecorate(__TestCase):
def test_key_with_mutating_del_and_exception(self):
data = list(range(10))
## dup = data[:]
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SortKiller(object):
def __init__(self, x):
if x > 2:
@ -389,7 +389,7 @@ class TestOptimizedCompares(__TestCase):
# This test is by ppperry. It ensures that unsafe_object_compare is
# verifying ms->key_richcompare == tp->richcompare before comparing.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class WackyComparator(int):
def __lt__(self, other):
elem.__class__ = WackyList2
@ -414,7 +414,7 @@ class TestOptimizedCompares(__TestCase):
# The following test is also by ppperry. It ensures that
# unsafe_object_compare handles Py_NotImplemented appropriately.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class PointlessComparator:
def __lt__(self, other):
return NotImplemented

View File

@ -60,15 +60,15 @@ index 9ce80c5e8ea..1080e85e31a 100644
+from test import support
+import seq_tests
import unittest
import gc
@@ -43,27 +97,30 @@ class TupleTest(seq_tests.CommonTest):
tuple(sequence=())
def test_keywords_in_subclass(self):
- class subclass(tuple):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(tuple):
+ pass
u = subclass([1, 2])
@ -76,11 +76,11 @@ index 9ce80c5e8ea..1080e85e31a 100644
self.assertEqual(list(u), [1, 2])
with self.assertRaises(TypeError):
subclass(sequence=())
- class subclass_with_init(tuple):
- def __init__(self, arg, newarg=None):
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(tuple):
+ def __init__(self, arg, newarg=None):
+ self.newarg = newarg
@ -88,13 +88,13 @@ index 9ce80c5e8ea..1080e85e31a 100644
self.assertIs(type(u), subclass_with_init)
self.assertEqual(list(u), [1, 2])
self.assertEqual(u.newarg, 3)
- class subclass_with_new(tuple):
- def __new__(cls, arg, newarg=None):
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(tuple):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -109,25 +109,25 @@ index 9ce80c5e8ea..1080e85e31a 100644
# Tuple subtypes must always be tracked
- class MyTuple(tuple):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyTuple(tuple):
+ pass
self.check_track_dynamic(MyTuple, True)
@support.cpython_only
@@ -404,7 +462,8 @@ class TupleTest(seq_tests.CommonTest):
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
# optimization causes failures in code that relies on distinct
# function addresses.
- class T(tuple): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(tuple): pass
with self.assertRaises(TypeError):
[3,] + T((1,2))
@@ -510,4 +569,4 @@ class TupleTest(seq_tests.CommonTest):
# pileup 262,143 mean 8.0 coll 262,143 z +92683.6
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -97,7 +97,7 @@ class TupleTest(seq_tests.CommonTest):
tuple(sequence=())
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(tuple):
pass
u = subclass([1, 2])
@ -106,7 +106,7 @@ class TupleTest(seq_tests.CommonTest):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(tuple):
def __init__(self, arg, newarg=None):
self.newarg = newarg
@ -115,7 +115,7 @@ class TupleTest(seq_tests.CommonTest):
self.assertEqual(list(u), [1, 2])
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(tuple):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -408,7 +408,7 @@ class TupleTest(seq_tests.CommonTest):
@support.cpython_only
def test_track_subtypes(self):
# Tuple subtypes must always be tracked
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyTuple(tuple):
pass
self.check_track_dynamic(MyTuple, True)
@ -462,7 +462,7 @@ class TupleTest(seq_tests.CommonTest):
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
# optimization causes failures in code that relies on distinct
# function addresses.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(tuple): pass
with self.assertRaises(TypeError):
[3,] + T((1,2))

View File

@ -58,29 +58,29 @@ index 312702c8e39..d3d8dbf394a 100644
+# ======= END DYNAMO PATCH =======
+
# Check every path through every method of UserList
from collections import UserList
-from test import list_tests
+import list_tests
import unittest
from test import support
@@ -56,9 +110,10 @@ class UserListTest(list_tests.CommonTest):
def test_getitemoverwriteiter(self):
# Verify that __getitem__ overrides *are* recognized by __iter__
- class T(self.type2test):
- def __getitem__(self, key):
- return str(key) + '!!!'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(self.type2test):
+ def __getitem__(self, key):
+ return str(key) + '!!!'
self.assertEqual(next(iter(T((1,2)))), "0!!!")
def test_userlist_copy(self):
@@ -69,9 +124,9 @@ class UserListTest(list_tests.CommonTest):
# Decorate existing test with recursion limit, because
# the test is for C structure, but `UserList` is a Python structure.
- test_repr_deep = support.infinite_recursion(25)(
@ -89,7 +89,7 @@ index 312702c8e39..d3d8dbf394a 100644
+ # test_repr_deep = support.infinite_recursion(25)(
+ # list_tests.CommonTest.test_repr_deep,
+ # )
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -110,7 +110,7 @@ class UserListTest(list_tests.CommonTest):
def test_getitemoverwriteiter(self):
# Verify that __getitem__ overrides *are* recognized by __iter__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(self.type2test):
def __getitem__(self, key):
return str(key) + '!!!'

View File

@ -24,84 +24,84 @@ index 8e9ed8500c7..66c18ad886a 100644
+# ======= END DYNAMO PATCH =======
+
"""Unit tests for the with statement specified in PEP 343."""
@@ -104,16 +124,17 @@ class MockNested(Nested):
return Nested.__exit__(self, *exc_info)
-class FailureTestCase(unittest.TestCase):
+class FailureTestCase(__TestCase):
def testNameError(self):
def fooNotDeclared():
with foo: pass
self.assertRaises(NameError, fooNotDeclared)
def testEnterAttributeError1(self):
- class LacksEnter(object):
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnter(object):
+ def __exit__(self, type, value, traceback):
+ pass
def fooLacksEnter():
foo = LacksEnter()
@@ -121,8 +142,9 @@ class FailureTestCase(unittest.TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)
def testEnterAttributeError2(self):
- class LacksEnterAndExit(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnterAndExit(object):
+ pass
def fooLacksEnterAndExit():
foo = LacksEnterAndExit()
@@ -130,9 +152,10 @@ class FailureTestCase(unittest.TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)
def testExitAttributeError(self):
- class LacksExit(object):
- def __enter__(self):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksExit(object):
+ def __enter__(self):
+ pass
def fooLacksExit():
foo = LacksExit()
@@ -162,11 +185,12 @@ class FailureTestCase(unittest.TestCase):
' pass')
def testEnterThrows(self):
- class EnterThrows(object):
- def __enter__(self):
- raise RuntimeError("Enter threw")
- def __exit__(self, *args):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class EnterThrows(object):
+ def __enter__(self):
+ raise RuntimeError("Enter threw")
+ def __exit__(self, *args):
+ pass
def shouldThrow():
ct = EnterThrows()
@@ -180,11 +204,12 @@ class FailureTestCase(unittest.TestCase):
self.assertEqual(self.foo, None)
def testExitThrows(self):
- class ExitThrows(object):
- def __enter__(self):
- return
- def __exit__(self, *args):
- raise RuntimeError(42)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ExitThrows(object):
+ def __enter__(self):
+ return
@ -111,17 +111,17 @@ index 8e9ed8500c7..66c18ad886a 100644
with ExitThrows():
pass
@@ -194,6 +219,7 @@ class ContextmanagerAssertionMixin(object):
def setUp(self):
self.TEST_EXCEPTION = RuntimeError("test exception")
+ super().setUp()
def assertInWithManagerInvariants(self, mock_manager):
self.assertTrue(mock_manager.enter_called)
@@ -237,7 +263,7 @@ class ContextmanagerAssertionMixin(object):
self.assertTrue(mock_generator.stopped)
-class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
+class NonexceptionalTestCase(__TestCase, ContextmanagerAssertionMixin):
def testInlineGeneratorSyntax(self):
@ -129,8 +129,8 @@ index 8e9ed8500c7..66c18ad886a 100644
pass
@@ -289,7 +315,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
self.assertAfterWithGeneratorInvariantsNoError(foo)
-class NestedNonexceptionalTestCase(unittest.TestCase,
+class NestedNonexceptionalTestCase(__TestCase,
ContextmanagerAssertionMixin):
@ -138,15 +138,15 @@ index 8e9ed8500c7..66c18ad886a 100644
with Nested(mock_contextmanager_generator()):
@@ -355,7 +381,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
self.assertAfterWithManagerInvariantsNoError(mock_nested)
-class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
+class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
def testSingleResource(self):
cm = mock_contextmanager_generator()
def shouldThrow():
@@ -466,11 +492,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
def testRaisedStopIteration2(self):
# From bug 1462485
- class cm(object):
@ -154,17 +154,17 @@ index 8e9ed8500c7..66c18ad886a 100644
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class cm(object):
+ def __enter__(self):
+ pass
+ def __exit__(self, type, value, traceback):
+ pass
def shouldThrow():
with cm():
@@ -507,11 +534,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
def testRaisedGeneratorExit2(self):
# From bug 1462485
- class cm (object):
@ -172,19 +172,19 @@ index 8e9ed8500c7..66c18ad886a 100644
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class cm (object):
+ def __enter__(self):
+ pass
+ def __exit__(self, type, value, traceback):
+ pass
def shouldThrow():
with cm():
@@ -523,16 +551,17 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
# issue4589: __exit__ return code may raise an exception
# when looking at its truth value.
- class cm(object):
- def __init__(self, bool_conversion):
- class Bool:
@ -195,7 +195,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- return 3
- def __exit__(self, a, b, c):
- return self.exit_result
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class cm(object):
+ def __init__(self, bool_conversion):
+ class Bool:
@ -206,25 +206,25 @@ index 8e9ed8500c7..66c18ad886a 100644
+ return 3
+ def __exit__(self, a, b, c):
+ return self.exit_result
def trueAsBool():
with cm(lambda: True):
@@ -550,7 +579,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
self.assertRaises(ZeroDivisionError, failAsBool)
-class NonLocalFlowControlTestCase(unittest.TestCase):
+class NonLocalFlowControlTestCase(__TestCase):
def testWithBreak(self):
counter = 0
@@ -607,7 +636,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase):
self.fail("Didn't raise RuntimeError")
-class AssignmentTargetTestCase(unittest.TestCase):
+class AssignmentTargetTestCase(__TestCase):
def testSingleComplexTarget(self):
targets = {1: [0, 1, 2]}
@@ -621,15 +650,17 @@ class AssignmentTargetTestCase(unittest.TestCase):
@ -232,17 +232,17 @@ index 8e9ed8500c7..66c18ad886a 100644
keys.sort()
self.assertEqual(keys, [1, 2])
- class C: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
self.assertEqual(hasattr(blah, "foo"), True)
def testMultipleComplexTargets(self):
- class C:
- def __enter__(self): return 1, 2, 3
- def __exit__(self, t, v, tb): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __enter__(self): return 1, 2, 3
+ def __exit__(self, t, v, tb): pass
@ -254,23 +254,23 @@ index 8e9ed8500c7..66c18ad886a 100644
with C() as (targets[1], targets[2], targets[3]):
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
- class B: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class B: pass
blah = B()
with C() as (blah.one, blah.two, blah.three):
self.assertEqual(blah.one, 1)
@@ -651,12 +683,13 @@ class AssignmentTargetTestCase(unittest.TestCase):
self.assertEqual(c, 4)
-class ExitSwallowsExceptionTestCase(unittest.TestCase):
+class ExitSwallowsExceptionTestCase(__TestCase):
def testExitTrueSwallowsException(self):
- class AfricanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class AfricanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return True
@ -279,12 +279,12 @@ index 8e9ed8500c7..66c18ad886a 100644
1/0
@@ -664,9 +697,10 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
self.fail("ZeroDivisionError should have been swallowed")
def testExitFalseDoesntSwallowException(self):
- class EuropeanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class EuropeanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return False
@ -293,16 +293,16 @@ index 8e9ed8500c7..66c18ad886a 100644
1/0
@@ -676,7 +710,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
self.fail("ZeroDivisionError should have been raised")
-class NestedWith(unittest.TestCase):
+class NestedWith(__TestCase):
class Dummy(object):
def __init__(self, value=None, gobble=False):
@@ -796,4 +830,4 @@ class NestedWith(unittest.TestCase):
if __name__ == '__main__':
- unittest.main()
+ run_tests()

View File

@ -131,7 +131,7 @@ class FailureTestCase(__TestCase):
self.assertRaises(NameError, fooNotDeclared)
def testEnterAttributeError1(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksEnter(object):
def __exit__(self, type, value, traceback):
pass
@ -142,7 +142,7 @@ class FailureTestCase(__TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)
def testEnterAttributeError2(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksEnterAndExit(object):
pass
@ -152,7 +152,7 @@ class FailureTestCase(__TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)
def testExitAttributeError(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksExit(object):
def __enter__(self):
pass
@ -185,7 +185,7 @@ class FailureTestCase(__TestCase):
' pass')
def testEnterThrows(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EnterThrows(object):
def __enter__(self):
raise RuntimeError("Enter threw")
@ -204,7 +204,7 @@ class FailureTestCase(__TestCase):
self.assertEqual(self.foo, None)
def testExitThrows(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ExitThrows(object):
def __enter__(self):
return
@ -492,7 +492,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
def testRaisedStopIteration2(self):
# From bug 1462485
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class cm(object):
def __enter__(self):
pass
@ -534,7 +534,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
def testRaisedGeneratorExit2(self):
# From bug 1462485
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class cm (object):
def __enter__(self):
pass
@ -551,7 +551,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
# issue4589: __exit__ return code may raise an exception
# when looking at its truth value.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class cm(object):
def __init__(self, bool_conversion):
class Bool:
@ -650,14 +650,14 @@ class AssignmentTargetTestCase(__TestCase):
keys = list(targets.keys())
keys.sort()
self.assertEqual(keys, [1, 2])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
self.assertEqual(hasattr(blah, "foo"), True)
def testMultipleComplexTargets(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __enter__(self): return 1, 2, 3
def __exit__(self, t, v, tb): pass
@ -668,7 +668,7 @@ class AssignmentTargetTestCase(__TestCase):
self.assertEqual(targets, {1: [3, 2, 1]})
with C() as (targets[1], targets[2], targets[3]):
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class B: pass
blah = B()
with C() as (blah.one, blah.two, blah.three):
@ -686,7 +686,7 @@ class AssignmentTargetTestCase(__TestCase):
class ExitSwallowsExceptionTestCase(__TestCase):
def testExitTrueSwallowsException(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class AfricanSwallow:
def __enter__(self): pass
def __exit__(self, t, v, tb): return True
@ -697,7 +697,7 @@ class ExitSwallowsExceptionTestCase(__TestCase):
self.fail("ZeroDivisionError should have been swallowed")
def testExitFalseDoesntSwallowException(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EuropeanSwallow:
def __enter__(self): pass
def __exit__(self, t, v, tb): return False

View File

@ -1721,13 +1721,13 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
):
f4(torch.randn(3))
def test_set_fullgraph(self):
def test_error_on_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
def f1(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
return x + 2
@ -1738,7 +1738,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
@torch.compile(backend=cnts)
def f2(x):
x = x + 1
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 2
@ -1748,7 +1748,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
@torch.compile(backend=cnts, fullgraph=True)
def f3(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
x = x + 2
torch._dynamo.graph_break()
@ -1766,7 +1766,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
@torch.compile(backend=cnts, fullgraph=True)
def f4(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.skip_frame()
return inner_f4(x)
@ -1774,11 +1774,11 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(f4(inp), inp + 7)
self.assertEqual(cnts.frame_count, 2)
def test_set_fullgraph_nested(self):
# set_fullgraph in a nested frame
def test_error_on_graph_break_nested(self):
# error_on_graph_break in a nested frame
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
@torch._dynamo.error_on_graph_break(False)
def inner_f5(x):
x = x + 2
torch._dynamo.graph_break()
@ -1795,7 +1795,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner_f6(x):
x = x + 2
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
return x + 4
@ -1810,7 +1810,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner_f7(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 4
@ -1822,18 +1822,18 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
f7(inp)
def test_set_fullgraph_nested_with_skip(self):
# set_fullgraph in a nested frame with a skipped frame in between
def test_error_on_graph_break_nested_with_skip(self):
# error_on_graph_break in a nested frame with a skipped frame in between
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
@torch._dynamo.error_on_graph_break(False)
def inner2_f8(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
def inner1_f8(x):
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.skip_frame()
return inner2_f8(x)
@ -1848,7 +1848,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner2_f9(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 4
@ -1864,10 +1864,10 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
f9(inp)
# test export with set_fullgraph(False) still errors
# test export with error_on_graph_break(False) still errors
def test_set_fullgraph_export(self):
@torch._dynamo.set_fullgraph(False)
def test_error_on_graph_break_export(self):
@torch._dynamo.error_on_graph_break(False)
def inner(x):
x = x + 2
torch._dynamo.graph_break()
@ -1880,7 +1880,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
torch._dynamo.export(f)(torch.ones(3))
def test_set_fullgraph_nested_deep(self):
def test_error_on_graph_break_nested_deep(self):
cnts = torch._dynamo.testing.CompileCounter()
def inner1_f1(x):
@ -1892,7 +1892,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
return inner1_f1(x)
def inner3_f1(x):
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
return inner2_f1(x)
def inner4_f1(x):
@ -1916,7 +1916,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
return inner1_f2(x)
def inner3_f2(x):
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
return inner2_f2(x)
def inner4_f2(x):
@ -1930,20 +1930,20 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
f2(inp)
def test_set_fullgraph_error(self):
def test_error_on_graph_break_error(self):
@torch.compile(backend="eager")
def f1():
with torch._dynamo.set_fullgraph(foo="bar"):
with torch._dynamo.error_on_graph_break(foo="bar"):
pass
@torch.compile(backend="eager")
def f2():
with torch._dynamo.set_fullgraph():
with torch._dynamo.error_on_graph_break():
pass
@torch.compile(backend="eager")
def f3():
with torch._dynamo.set_fullgraph("foo"):
with torch._dynamo.error_on_graph_break("foo"):
pass
with self.assertRaises(Exception):

View File

@ -21,6 +21,7 @@ from .decorators import (
disable,
disallow_in_graph,
dont_skip_tracing,
error_on_graph_break,
forbid_in_graph,
graph_break,
mark_dynamic,
@ -30,7 +31,6 @@ from .decorators import (
nonstrict_trace,
patch_dynamo_config,
run,
set_fullgraph,
set_stance,
skip_frame,
substitute_in_graph,
@ -90,7 +90,7 @@ __all__ = [
"replay",
"reset",
"run",
"set_fullgraph",
"error_on_graph_break",
"set_stance",
"skip_frame",
"substitute_in_graph",

View File

@ -1747,7 +1747,7 @@ def replay(filename: str) -> None:
record = ExecutionRecord.load(in_file)
record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
with decorators.set_fullgraph(fullgraph=False):
with decorators.error_on_graph_break(False):
try:
_compile(
record.code,

View File

@ -918,15 +918,15 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any:
return ctx
class SetFullgraphDecoratorContextManager:
def __init__(self, fullgraph: bool) -> None:
self.fullgraph = fullgraph
class ErrorOnGraphBreakDecoratorContextManager:
def __init__(self, error_on_graph_break: bool) -> None:
self.error_on_graph_break = error_on_graph_break
__call__ = wrap_dunder_call_ctx_manager
def __enter__(self) -> None:
self.prev_fullgraph = _get_error_on_graph_break()
_set_error_on_graph_break(self.fullgraph)
self.prev_error_on_graph_break = _get_error_on_graph_break()
_set_error_on_graph_break(self.error_on_graph_break)
def __exit__(
self,
@ -934,14 +934,16 @@ class SetFullgraphDecoratorContextManager:
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
_set_error_on_graph_break(self.prev_fullgraph)
_set_error_on_graph_break(self.prev_error_on_graph_break)
def set_fullgraph(fullgraph: bool) -> SetFullgraphDecoratorContextManager:
def error_on_graph_break(
error_on_graph_break: bool,
) -> ErrorOnGraphBreakDecoratorContextManager:
"""
Context manager/decorator to toggle fullgraph setting.
Context manager/decorator to toggle error_on_graph_break (i.e. torch.compile's fullgraph) setting.
More precisely, when encountering a graph break, we will decide to resume (fullgraph=False)
or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break.
"""
return SetFullgraphDecoratorContextManager(fullgraph)
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)

View File

@ -858,7 +858,9 @@ class _TorchDynamoContext:
# hooks to properly handle inlining
compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined]
external_utils.wrap_inline_with_set_fullgraph(fn, self.error_on_graph_break)
external_utils.wrap_inline_with_error_on_graph_break(
fn, self.error_on_graph_break
)
)
# Save the function pointer to find the original callable while nesting

View File

@ -229,23 +229,23 @@ def call_accumulate_grad(
variable.grad = updated_grad[0]
def wrap_inline_with_set_fullgraph(
fn: Callable[_P, _R], fullgraph: bool
def wrap_inline_with_error_on_graph_break(
fn: Callable[_P, _R], error_on_graph_break: bool
) -> Callable[_P, _R]:
# NB: need multiple definitions in order to prevent `fullgraph` from
# being a freevar of wrapper
if fullgraph:
if error_on_graph_break:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
return fn(*args, **kwargs)
else:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
return fn(*args, **kwargs)
return wrapper

View File

@ -348,7 +348,7 @@ manual_torch_name_rule_map: dict[
"torch._dynamo.mark_static": UserFunctionVariable,
"torch._dynamo.nonstrict_trace": UserFunctionVariable,
"torch._dynamo.patch_dynamo_config": UserFunctionVariable,
"torch._dynamo.set_fullgraph": UserFunctionVariable,
"torch._dynamo.error_on_graph_break": UserFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,

View File

@ -27,6 +27,7 @@ from .ctx_manager import (
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
FSDPParamGroupUseTrainingStateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
@ -34,7 +35,6 @@ from .ctx_manager import (
InferenceModeVariable,
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFullgraphVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
@ -200,7 +200,7 @@ __all__ = [
"RemovableHandleVariable",
"RepeatIteratorVariable",
"SDPAParamsVariable",
"SetFullgraphVariable",
"ErrorOnGraphBreakVariable",
"SkipFunctionVariable",
"SliceVariable",
"StringFormatVariable",

View File

@ -168,10 +168,10 @@ from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
AutocastModeVariable,
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
EventVariable,
NullContextVariable,
PreserveVersionContextVariable,
SetFullgraphVariable,
StreamContextVariable,
StreamVariable,
)
@ -630,7 +630,7 @@ class VariableBuilder:
from ..decorators import (
DynamoConfigPatchProxy,
SetFullgraphDecoratorContextManager,
ErrorOnGraphBreakDecoratorContextManager,
)
if has_triton():
@ -988,8 +988,8 @@ class VariableBuilder:
)
elif isinstance(value, DynamoConfigPatchProxy):
return DynamoConfigPatchVariable(value.changes)
elif isinstance(value, SetFullgraphDecoratorContextManager):
return SetFullgraphVariable(value.fullgraph)
elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager):
return ErrorOnGraphBreakVariable(value.error_on_graph_break)
elif callable(value) and trace_rules.lookup_callable(value) is not None:
if trace_rules.is_callable_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True

View File

@ -1429,12 +1429,12 @@ class DynamoConfigPatchVariable(ContextWrappingVariable):
return "patch_dynamo_config"
class SetFullgraphVariable(ContextWrappingVariable):
"""represents torch._dynamo.set_fullgraph"""
class ErrorOnGraphBreakVariable(ContextWrappingVariable):
"""represents torch._dynamo.error_on_graph_break"""
def __init__(self, fullgraph, **kwargs) -> None:
def __init__(self, error_on_graph_break, **kwargs) -> None:
super().__init__(
target_values=(fullgraph,),
target_values=(error_on_graph_break,),
initial_values=(_get_error_on_graph_break(),),
**kwargs,
)
@ -1447,7 +1447,7 @@ class SetFullgraphVariable(ContextWrappingVariable):
return "torch._dynamo"
def fn_name(self):
return "set_fullgraph"
return "error_on_graph_break"
class WithExitFunctionVariable(VariableTracker):

View File

@ -521,15 +521,17 @@ class UserFunctionVariable(BaseUserFunctionVariable):
"Please fix your call to patch_dynamo_config by using simpler inputs. "
f"args: {args}, kwargs: {kwargs}"
) from e
elif self.fn is torch._dynamo.set_fullgraph:
elif self.fn is torch._dynamo.error_on_graph_break:
try:
bound = inspect.signature(self.fn).bind(*args, **kwargs)
fullgraph = bound.arguments["fullgraph"].as_python_constant()
assert isinstance(fullgraph, bool)
return variables.SetFullgraphVariable(fullgraph)
error_on_graph_break = bound.arguments[
"error_on_graph_break"
].as_python_constant()
assert isinstance(error_on_graph_break, bool)
return variables.ErrorOnGraphBreakVariable(error_on_graph_break)
except Exception as e:
raise RuntimeError(
"Improper set_fullgraph() call. Please fix your call to set_fullgraph(). "
"Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). "
f"args: {args}, kwargs: {kwargs}"
) from e
# Handle a `nonstrict_trace(fn)` call