[dynamo] rename set_fullgraph to error_on_graph_break (#161739)

Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a followup PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph` so that `fullgraph` really returns 1 graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still private API (there are no internal callsites yet, and there are no significant OSS callsites yet).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
William Wen 2025-09-03 14:25:40 -07:00 committed by PyTorch MergeBot
parent 1281470155
commit 8678d831c4
57 changed files with 1577 additions and 1571 deletions

View File

@ -90,7 +90,7 @@ index dbc5ef4f9f2..af717703053 100644
- raise StopIteration
- def __length_hint__(self):
- return sys.maxsize
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomIter:
+ def __iter__(self):
+ return self
@ -107,7 +107,7 @@ index dbc5ef4f9f2..af717703053 100644
- class BadExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass
@ -128,7 +128,7 @@ index dbc5ef4f9f2..af717703053 100644
- class BadCmp2:
- def __eq__(self, other):
- raise BadExc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadCmp2:
+ def __eq__(self, other):
+ raise BadExc()
@ -146,7 +146,7 @@ index dbc5ef4f9f2..af717703053 100644
- def __eq__(self, other):
- del self.victim[:]
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Test modifying the list during index's iteration
+ class EvilCmp:
+ def __init__(self, victim):

View File

@ -319,7 +319,7 @@ class CommonTest(seq_tests.CommonTest):
self.assertRaises(TypeError, a.extend)
# overflow test. issue1621
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomIter:
def __iter__(self):
return self
@ -387,7 +387,7 @@ class CommonTest(seq_tests.CommonTest):
a = self.type2test([NEVER_EQ])
self.assertRaises(ValueError, a.remove, ALWAYS_EQ)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadExc(Exception):
pass
@ -400,7 +400,7 @@ class CommonTest(seq_tests.CommonTest):
a = self.type2test([0, 1, 2, 3])
self.assertRaises(BadExc, a.remove, BadCmp())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadCmp2:
def __eq__(self, other):
raise BadExc()
@ -428,7 +428,7 @@ class CommonTest(seq_tests.CommonTest):
self.assertRaises(ValueError, a.index, 2, 0, 4)
self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2]))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Test modifying the list during index's iteration
class EvilCmp:
def __init__(self, victim):

View File

@ -79,7 +79,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return self.d.keys()
- def __getitem__(self, i):
- return self.d[i]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SimpleUserDict:
+ def __init__(self):
+ self.d = outerself.reference
@ -94,14 +94,14 @@ index ed89a81a6ea..b19cec7cb23 100644
self.assertEqual(i1, i2)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
d = self._empty_mapping()
- class FailingUserDict:
- def keys(self):
- raise Exc
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ raise Exc
@ -124,7 +124,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return BogonIter()
- def __getitem__(self, key):
- return key
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ class BogonIter:
@ -158,7 +158,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return BogonIter()
- def __getitem__(self, key):
- raise Exc
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ class BogonIter:
@ -183,7 +183,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return self
- def __next__(self):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class badseq(object):
+ def __iter__(self):
+ return self
@ -203,7 +203,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return self.d.keys()
- def __getitem__(self, i):
- return self.d[i]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SimpleUserDict:
+ def __init__(self):
+ self.d = {1:1, 2:2, 3:3}
@ -219,7 +219,7 @@ index ed89a81a6ea..b19cec7cb23 100644
self.assertEqual(d.fromkeys(g()), {1:None})
self.assertRaises(TypeError, {}.fromkeys, 3)
- class dictlike(self.type2test): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class dictlike(self.type2test): pass
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
@ -229,7 +229,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- class mydict(self.type2test):
- def __new__(cls):
- return collections.UserDict()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mydict(self.type2test):
+ def __new__(cls):
+ return collections.UserDict()
@ -239,7 +239,7 @@ index ed89a81a6ea..b19cec7cb23 100644
self.assertRaises(TypeError, dict.fromkeys)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class baddict1(self.type2test):
@ -256,7 +256,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- return self
- def __next__(self):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadSeq(object):
+ def __iter__(self):
+ return self
@ -268,7 +268,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- class baddict2(self.type2test):
- def __setitem__(self, key, value):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class baddict2(self.type2test):
+ def __setitem__(self, key, value):
+ raise Exc()
@ -280,7 +280,7 @@ index ed89a81a6ea..b19cec7cb23 100644
def test_getitem(self):
TestMappingProtocol.test_getitem(self)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadEq(object):
@ -305,7 +305,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- raise Exc()
- else:
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadHash(object):
+ fail = False
+ def __hash__(self):
@ -323,7 +323,7 @@ index ed89a81a6ea..b19cec7cb23 100644
- class mydict(self.type2test):
- def __new__(cls):
- return collections.UserDict()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mydict(self.type2test):
+ def __new__(cls):
+ return collections.UserDict()
@ -335,7 +335,7 @@ index ed89a81a6ea..b19cec7cb23 100644
TestMappingProtocol.test_pop(self)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadHash(object):
@ -360,7 +360,7 @@ index ed89a81a6ea..b19cec7cb23 100644
self.assertEqual(repr(d), '{1: {...}}')
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadRepr(object):
@ -377,7 +377,7 @@ index ed89a81a6ea..b19cec7cb23 100644
self._full_mapping({1: 2}))
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadCmp(object):
@ -398,7 +398,7 @@ index ed89a81a6ea..b19cec7cb23 100644
TestMappingProtocol.test_setdefault(self)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadHash(object):

View File

@ -250,7 +250,7 @@ class BasicTestMappingProtocol(__TestCase):
self.assertRaises((TypeError, AttributeError), d.update, 42)
outerself = self
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SimpleUserDict:
def __init__(self):
self.d = outerself.reference
@ -264,11 +264,11 @@ class BasicTestMappingProtocol(__TestCase):
i2 = sorted(self.reference.items())
self.assertEqual(i1, i2)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
d = self._empty_mapping()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
raise Exc
@ -276,7 +276,7 @@ class BasicTestMappingProtocol(__TestCase):
d.clear()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -294,7 +294,7 @@ class BasicTestMappingProtocol(__TestCase):
return key
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -314,7 +314,7 @@ class BasicTestMappingProtocol(__TestCase):
self.assertRaises(Exc, d.update, FailingUserDict())
d = self._empty_mapping()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class badseq(object):
def __iter__(self):
return self
@ -469,7 +469,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
d.update(self._full_mapping({1:2, 3:4, 5:6}).items())
self.assertEqual(d, {1:2, 2:4, 3:4, 5:6})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SimpleUserDict:
def __init__(self):
self.d = {1:1, 2:2, 3:3}
@ -492,14 +492,14 @@ class TestMappingProtocol(BasicTestMappingProtocol):
yield 1
self.assertEqual(d.fromkeys(g()), {1:None})
self.assertRaises(TypeError, {}.fromkeys, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class dictlike(self.type2test): pass
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
self.assertTrue(dictlike.fromkeys('a').__class__ is dictlike)
self.assertTrue(dictlike().fromkeys('a').__class__ is dictlike)
self.assertTrue(type(dictlike.fromkeys('a')) is dictlike)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mydict(self.type2test):
def __new__(cls):
return collections.UserDict()
@ -508,7 +508,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
self.assertIsInstance(ud, collections.UserDict)
self.assertRaises(TypeError, dict.fromkeys)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class baddict1(self.type2test):
@ -517,7 +517,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
self.assertRaises(Exc, baddict1.fromkeys, [1])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadSeq(object):
def __iter__(self):
return self
@ -526,7 +526,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
self.assertRaises(Exc, self.type2test.fromkeys, BadSeq())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict2(self.type2test):
def __setitem__(self, key, value):
raise Exc()
@ -603,7 +603,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_getitem(self):
TestMappingProtocol.test_getitem(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadEq(object):
@ -616,7 +616,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
d[BadEq()] = 42
self.assertRaises(KeyError, d.__getitem__, 23)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadHash(object):
fail = False
def __hash__(self):
@ -633,7 +633,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_fromkeys(self):
TestMappingProtocol.test_fromkeys(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mydict(self.type2test):
def __new__(cls):
return collections.UserDict()
@ -644,7 +644,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_pop(self):
TestMappingProtocol.test_pop(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -683,7 +683,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
d[1] = d
self.assertEqual(repr(d), '{1: {...}}')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadRepr(object):
@ -706,7 +706,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
self.assertEqual(self._full_mapping({1: 2}),
self._full_mapping({1: 2}))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadCmp(object):
@ -723,7 +723,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
def test_setdefault(self):
TestMappingProtocol.test_setdefault(self)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):

View File

@ -80,7 +80,7 @@ index 719c9434a16..290e57c04a0 100644
- return len(self.__data)
- def __getitem__(self, i):
- return self.__data[i]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class OtherSeq:
+ def __init__(self, initseq):
+ self.__data = initseq
@ -100,7 +100,7 @@ index 719c9434a16..290e57c04a0 100644
- class StopCompares:
- def __eq__(self, other):
- raise DoNotTestEq
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DoNotTestEq(Exception):
+ pass
+ class StopCompares:
@ -115,7 +115,7 @@ index 719c9434a16..290e57c04a0 100644
- class subclass(self.type2test):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(self.type2test):
+ pass
u3 = subclass([0, 1])
@ -128,7 +128,7 @@ index 719c9434a16..290e57c04a0 100644
- class T(self.type2test):
- def __getitem__(self, key):
- return str(key) + '!!!'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(self.type2test):
+ def __getitem__(self, key):
+ return str(key) + '!!!'
@ -141,7 +141,7 @@ index 719c9434a16..290e57c04a0 100644
- class BadExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass
@ -164,7 +164,7 @@ index 719c9434a16..290e57c04a0 100644
- class BadExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadExc(Exception):
+ pass

View File

@ -169,7 +169,7 @@ class CommonTest(__TestCase):
uu2 = self.type2test(u2)
v = self.type2test(tuple(u))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class OtherSeq:
def __init__(self, initseq):
self.__data = initseq
@ -294,7 +294,7 @@ class CommonTest(__TestCase):
# Sequences must test in-order. If a rich comparison has side
# effects, these will be visible to tests against later members.
# In this test, the "side effect" is a short-circuiting raise.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DoNotTestEq(Exception):
pass
class StopCompares:
@ -339,7 +339,7 @@ class CommonTest(__TestCase):
self.assertEqual(u2+u2+u2, u2*3)
self.assertEqual(u2+u2+u2, 3*u2)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(self.type2test):
pass
u3 = subclass([0, 1])
@ -368,7 +368,7 @@ class CommonTest(__TestCase):
def test_getitemoverwriteiter(self):
# Verify that __getitem__ overrides are not recognized by __iter__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(self.type2test):
def __getitem__(self, key):
return str(key) + '!!!'
@ -419,7 +419,7 @@ class CommonTest(__TestCase):
self.assertRaises(TypeError, a.count)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadExc(Exception):
pass
@ -453,7 +453,7 @@ class CommonTest(__TestCase):
self.assertRaises(TypeError, u.index)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadExc(Exception):
pass

View File

@ -37,7 +37,7 @@ index 34ecb45f161..12b719c432b 100644
try:
- class C(bool):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(bool):
+ pass
except TypeError:
@ -50,7 +50,7 @@ index 34ecb45f161..12b719c432b 100644
- class Foo(object):
- def __bool__(self):
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo(object):
+ def __bool__(self):
+ return self
@ -59,7 +59,7 @@ index 34ecb45f161..12b719c432b 100644
- class Bar(object):
- def __bool__(self):
- return "Yes"
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Bar(object):
+ def __bool__(self):
+ return "Yes"
@ -68,7 +68,7 @@ index 34ecb45f161..12b719c432b 100644
- class Baz(int):
- def __bool__(self):
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Baz(int):
+ def __bool__(self):
+ return self
@ -78,7 +78,7 @@ index 34ecb45f161..12b719c432b 100644
- class Spam(int):
- def __bool__(self):
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Spam(int):
+ def __bool__(self):
+ return 1
@ -87,7 +87,7 @@ index 34ecb45f161..12b719c432b 100644
- class Eggs:
- def __len__(self):
- return -1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Eggs:
+ def __len__(self):
+ return -1
@ -97,7 +97,7 @@ index 34ecb45f161..12b719c432b 100644
- class SymbolicBool:
- def __bool__(self):
- raise TypeError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SymbolicBool:
+ def __bool__(self):
+ raise TypeError
@ -118,7 +118,7 @@ index 34ecb45f161..12b719c432b 100644
- class A:
- def __len__(self):
- return badval
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def __len__(self):
+ return badval
@ -131,7 +131,7 @@ index 34ecb45f161..12b719c432b 100644
def test_blocked(self):
- class A:
- __bool__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ __bool__ = None
self.assertRaises(TypeError, bool, A())
@ -140,7 +140,7 @@ index 34ecb45f161..12b719c432b 100644
- def __len__(self):
- return 10
- __bool__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class B:
+ def __len__(self):
+ return 10
@ -158,7 +158,7 @@ index 34ecb45f161..12b719c432b 100644
- def __bool__(self):
- self.count += 1
- return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __init__(self):
+ self.count = 0

View File

@ -29,7 +29,7 @@ class BoolTest(__TestCase):
def test_subclass(self):
try:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(bool):
pass
except TypeError:
@ -328,39 +328,39 @@ class BoolTest(__TestCase):
# from __bool__(). This isn't really a bool test, but
# it's related.
check = lambda o: self.assertRaises(TypeError, bool, o)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo(object):
def __bool__(self):
return self
check(Foo())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Bar(object):
def __bool__(self):
return "Yes"
check(Bar())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Baz(int):
def __bool__(self):
return self
check(Baz())
# __bool__() must return a bool not an int
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Spam(int):
def __bool__(self):
return 1
check(Spam())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Eggs:
def __len__(self):
return -1
self.assertRaises(ValueError, bool, Eggs())
def test_interpreter_convert_to_bool_raises(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SymbolicBool:
def __bool__(self):
raise TypeError
@ -388,7 +388,7 @@ class BoolTest(__TestCase):
# this test just tests our assumptions about __len__
# this will start failing if __len__ changes assertions
for badval in ['illegal', -1, 1 << 32]:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def __len__(self):
return badval
@ -401,12 +401,12 @@ class BoolTest(__TestCase):
self.assertEqual(str(e_bool), str(e_len))
def test_blocked(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
__bool__ = None
self.assertRaises(TypeError, bool, A())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class B:
def __len__(self):
return 10
@ -424,7 +424,7 @@ class BoolTest(__TestCase):
self.assertIs(type(False.imag), int)
def test_bool_called_at_least_once(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __init__(self):
self.count = 0

View File

@ -127,7 +127,7 @@ index a96a5780b31..d00dfca8a17 100644
- class MyComplexException:
- def __complex__(self):
- raise SomeException
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyComplex:
+ def __init__(self, value):
+ self.value = value

View File

@ -251,7 +251,7 @@ class CMathTests(__TestCase):
# end up being passed to the cmath functions
# usual case: new-style class implementing __complex__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyComplex:
def __init__(self, value):
self.value = value

View File

@ -42,7 +42,7 @@ index cafc44007d1..4571e5a14fd 100644
- class A(UserDict):
- def __missing__(self, key):
- return 456
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A(UserDict):
+ def __missing__(self, key):
+ return 456
@ -65,7 +65,7 @@ index cafc44007d1..4571e5a14fd 100644
- class DefaultChainMap(ChainMap):
- def __missing__(self, key):
- return 999
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultChainMap(ChainMap):
+ def __missing__(self, key):
+ return 999
@ -83,7 +83,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __getitem__(self, item):
- self.called = True
- UserDict.__getitem__(self, item)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DictWithGetItem(UserDict):
+ def __init__(self, *args, **kwds):
+ self.called = False
@ -107,7 +107,7 @@ index cafc44007d1..4571e5a14fd 100644
- if isinstance(key, str):
- key = key.lower()
- return dict.__contains__(self, key)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class lowerdict(dict):
+ def __getitem__(self, key):
+ if isinstance(key, str):
@ -135,7 +135,7 @@ index cafc44007d1..4571e5a14fd 100644
def test_namedtuple_subclass_issue_24931(self):
- class Point(namedtuple('_Point', ['x', 'y'])):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Point(namedtuple('_Point', ['x', 'y'])):
+ pass
@ -153,7 +153,7 @@ index cafc44007d1..4571e5a14fd 100644
# everything should work will all required methods are present
- C = type('C', (abc,), methodstubs)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ C = type('C', (abc,), methodstubs)
C()
@ -183,7 +183,7 @@ index cafc44007d1..4571e5a14fd 100644
- class I(Iterable):
- def __iter__(self):
- return super().__iter__()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class I(Iterable):
+ def __iter__(self):
@ -197,7 +197,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __iter__(self): return iter([])
- class ItBlocked(It):
- __iter__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class It:
+ def __iter__(self): return iter([])
@ -216,7 +216,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(list())
- def __reversed__(self):
- return iter(list())
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class R(Reversible):
+ def __iter__(self):
@ -231,7 +231,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __reversed__(self): return reversed([])
- class RevPlusIter(RevNoIter):
- def __iter__(self): return iter([])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check reversible non-iterable (which is not Reversible)
+ class RevNoIter:
+ def __reversed__(self): return reversed([])
@ -249,7 +249,7 @@ index cafc44007d1..4571e5a14fd 100644
- __iter__ = None
- class RevRevBlocked(Rev):
- __reversed__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class Rev:
+ def __iter__(self): return iter([])
@ -274,7 +274,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __contains__(self, item):
- return False
- class DerCol(Col): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check direct subclassing
+ class Col(Collection):
+ def __iter__(self):
@ -300,7 +300,7 @@ index cafc44007d1..4571e5a14fd 100644
- class ColNoCont:
- def __iter__(self): return iter([])
- def __len__(self): return 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ColNoIter:
+ def __len__(self): return 0
+ def __contains__(self, item): return False
@ -326,7 +326,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __contains__(self): return True
- __iter__ = None
+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking
+ class SizeBlock:
+ def __iter__(self): return iter([])
@ -350,7 +350,7 @@ index cafc44007d1..4571e5a14fd 100644
- return False
- class NonCol(ColImpl):
- __contains__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Check None blocking in subclass
+ class ColImpl:
+ def __iter__(self):
@ -373,7 +373,7 @@ index cafc44007d1..4571e5a14fd 100644
- def __next__(self):
- yield 1
- return
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Issue 10565
+ class NextOnly:
+ def __next__(self):
@ -398,7 +398,7 @@ index cafc44007d1..4571e5a14fd 100644
- def close(self): pass
- def send(self, value): return value
- def throw(self, typ, val=None, tb=None): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NonGen1:
+ def __iter__(self): return self
+ def __next__(self): return None
@ -428,7 +428,7 @@ index cafc44007d1..4571e5a14fd 100644
- def close(self): pass
- def send(self, value): return value
- def throw(self, typ, val=None, tb=None): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Gen:
+ def __iter__(self): return self
+ def __next__(self): return None
@ -456,7 +456,7 @@ index cafc44007d1..4571e5a14fd 100644
- class FailOnClose(Generator):
- def send(self, value): return value
- def throw(self, *args): raise ValueError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailOnClose(Generator):
+ def send(self, value): return value
+ def throw(self, *args): raise ValueError
@ -466,7 +466,7 @@ index cafc44007d1..4571e5a14fd 100644
- class IgnoreGeneratorExit(Generator):
- def send(self, value): return value
- def throw(self, *args): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IgnoreGeneratorExit(Generator):
+ def send(self, value): return value
+ def throw(self, *args): pass
@ -479,7 +479,7 @@ index cafc44007d1..4571e5a14fd 100644
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
- class C(B):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(B):
+ pass
self.assertTrue(issubclass(C, B))
@ -489,7 +489,7 @@ index cafc44007d1..4571e5a14fd 100644
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
- class C:
- __hash__ = None # Make sure it isn't hashable by default
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ __hash__ = None # Make sure it isn't hashable by default
self.assertFalse(issubclass(C, B), B.__name__)
@ -506,7 +506,7 @@ index cafc44007d1..4571e5a14fd 100644
- return 0
- def __iter__(self):
- return iter([])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __contains__(self, x):
+ return False
@ -530,7 +530,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
@ -556,7 +556,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
@ -582,7 +582,7 @@ index cafc44007d1..4571e5a14fd 100644
- return iter(self.contents)
- def __len__(self):
- return len([x for x in self.contents])
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
@ -621,7 +621,7 @@ index cafc44007d1..4571e5a14fd 100644
- return result
- def __repr__(self):
- return "MySet(%s)" % repr(list(self))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySet(MutableSet):
+ __slots__=['__s']
+ def __init__(self,items=None):
@ -669,7 +669,7 @@ index cafc44007d1..4571e5a14fd 100644
- return NotImplemented
- def __lt__(self, x):
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyComparableSet(Set):
+ def __contains__(self, x):
+ return False
@ -702,7 +702,7 @@ index cafc44007d1..4571e5a14fd 100644
- return self._seq[index]
- def __len__(self):
- return len(self._seq)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomSequence(Sequence):
+ def __init__(self, seq):
+ self._seq = seq
@ -723,7 +723,7 @@ index cafc44007d1..4571e5a14fd 100644
- raise ValueError('created_by must be specified')
- self.created_by = created_by
- self._values = set(values)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SetUsingInstanceFromIterable(MutableSet):
+ def __init__(self, values, created_by):
+ if not created_by:
@ -781,7 +781,7 @@ index cafc44007d1..4571e5a14fd 100644
- return len(self.data)
- def __repr__(self):
- return 'Set({!r})'.format(self.data)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ListSet(Set):
+ def __init__(self, elements=()):
+ self.data = []
@ -810,7 +810,7 @@ index cafc44007d1..4571e5a14fd 100644
- raise IndexError
- def __iter__(self):
- return iter(())
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyMapping(Mapping):
+ def __len__(self):
+ return 0
@ -837,7 +837,7 @@ index cafc44007d1..4571e5a14fd 100644
- class SequenceSubclass(Sequence):
- def __init__(self, seq=()):
- self.seq = seq
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SequenceSubclass(Sequence):
+ def __init__(self, seq=()):
+ self.seq = seq
@ -861,7 +861,7 @@ index cafc44007d1..4571e5a14fd 100644
- class MutableSequenceSubclass(MutableSequence):
- def __init__(self):
- self.lst = []
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MutableSequenceSubclass(MutableSequence):
+ def __init__(self):
+ self.lst = []
@ -908,7 +908,7 @@ index cafc44007d1..4571e5a14fd 100644
def test_copy_subclass(self):
- class MyCounter(Counter):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyCounter(Counter):
+ pass
c = MyCounter('slartibartfast')

View File

@ -93,7 +93,7 @@ class TestUserObjects(__TestCase):
self._copy_test(obj)
def test_dict_missing(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A(UserDict):
def __missing__(self, key):
return 456
@ -193,7 +193,7 @@ class TestChainMap(__TestCase):
self.assertTrue(ChainMap({}, {1:2}))
def test_missing(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultChainMap(ChainMap):
def __missing__(self, key):
return 999
@ -228,7 +228,7 @@ class TestChainMap(__TestCase):
('i', 9999), ('j', 0)])
def test_iter_not_calling_getitem_on_maps(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DictWithGetItem(UserDict):
def __init__(self, *args, **kwds):
self.called = False
@ -260,7 +260,7 @@ class TestChainMap(__TestCase):
self.assertIs(m, d.maps[0])
# Use a different map than a dict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class lowerdict(dict):
def __getitem__(self, key):
if isinstance(key, str):
@ -690,7 +690,7 @@ class TestNamedTuple(__TestCase):
NT = namedtuple('NT', ['abc', 'def'], False, True)
def test_namedtuple_subclass_issue_24931(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Point(namedtuple('_Point', ['x', 'y'])):
pass
@ -753,7 +753,7 @@ class ABCTestCase(__TestCase):
methodstubs = dict.fromkeys(names, lambda s, *args: 0)
# everything should work will all required methods are present
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
C = type('C', (abc,), methodstubs)
C()
@ -1011,7 +1011,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
for x in samples:
self.assertIsInstance(x, Iterable)
self.assertTrue(issubclass(type(x), Iterable), repr(type(x)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check direct subclassing
class I(Iterable):
def __iter__(self):
@ -1020,7 +1020,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(issubclass(str, I))
self.validate_abstract_methods(Iterable, '__iter__')
self.validate_isinstance(Iterable, '__iter__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking
class It:
def __iter__(self): return iter([])
@ -1055,7 +1055,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(Sequence, Reversible), repr(Sequence))
self.assertFalse(issubclass(Mapping, Reversible), repr(Mapping))
self.assertFalse(issubclass(MutableMapping, Reversible), repr(MutableMapping))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check direct subclassing
class R(Reversible):
def __iter__(self):
@ -1065,7 +1065,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertEqual(list(reversed(R())), [])
self.assertFalse(issubclass(float, R))
self.validate_abstract_methods(Reversible, '__reversed__', '__iter__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check reversible non-iterable (which is not Reversible)
class RevNoIter:
def __reversed__(self): return reversed([])
@ -1075,7 +1075,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(isinstance(RevNoIter(), Reversible))
self.assertTrue(issubclass(RevPlusIter, Reversible))
self.assertTrue(isinstance(RevPlusIter(), Reversible))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking
class Rev:
def __iter__(self): return iter([])
@ -1117,7 +1117,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(Set, Collection), repr(Set))
self.assertTrue(issubclass(MutableSet, Collection), repr(MutableSet))
self.assertTrue(issubclass(Sequence, Collection), repr(MutableSet))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check direct subclassing
class Col(Collection):
def __iter__(self):
@ -1138,7 +1138,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.validate_abstract_methods(Collection, '__len__', '__iter__',
'__contains__')
# Check sized container non-iterable (which is not Collection) etc.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ColNoIter:
def __len__(self): return 0
def __contains__(self, item): return False
@ -1155,7 +1155,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(issubclass(ColNoCont, Collection))
self.assertFalse(isinstance(ColNoCont(), Collection))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking
class SizeBlock:
def __iter__(self): return iter([])
@ -1169,7 +1169,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(isinstance(SizeBlock(), Collection))
self.assertFalse(issubclass(IterBlock, Collection))
self.assertFalse(isinstance(IterBlock(), Collection))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Check None blocking in subclass
class ColImpl:
def __iter__(self):
@ -1202,7 +1202,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
self.validate_abstract_methods(Iterator, '__next__', '__iter__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Issue 10565
class NextOnly:
def __next__(self):
@ -1211,7 +1211,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertNotIsInstance(NextOnly(), Iterator)
def test_Generator(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NonGen1:
def __iter__(self): return self
def __next__(self): return None
@ -1236,7 +1236,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertNotIsInstance(x, Generator)
self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Gen:
def __iter__(self): return self
def __next__(self): return None
@ -1271,14 +1271,14 @@ class TestOneTrickPonyABCs(ABCTestCase):
mgen.throw, ValueError, ValueError("huhu"))
self.assertRaises(StopIteration, mgen.throw, StopIteration())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailOnClose(Generator):
def send(self, value): return value
def throw(self, *args): raise ValueError
self.assertRaises(ValueError, FailOnClose().close)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IgnoreGeneratorExit(Generator):
def send(self, value): return value
def throw(self, *args): pass
@ -1424,7 +1424,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_direct_subclassing(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(B):
pass
self.assertTrue(issubclass(C, B))
@ -1432,7 +1432,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_registration(self):
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
__hash__ = None # Make sure it isn't hashable by default
self.assertFalse(issubclass(C, B), B.__name__)
@ -1470,7 +1470,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertIsInstance(sample(), Set)
self.assertTrue(issubclass(sample, Set))
self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __contains__(self, x):
return False
@ -1496,7 +1496,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(hash(a) == hash(b))
def test_isdisjoint_Set(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __init__(self, itr):
self.contents = itr
@ -1513,7 +1513,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertFalse(s1.isdisjoint(s3))
def test_equality_Set(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __init__(self, itr):
self.contents = itr
@ -1536,7 +1536,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertNotEqual(s2, s3)
def test_arithmetic_Set(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(Set):
def __init__(self, itr):
self.contents = itr
@ -1567,7 +1567,7 @@ class TestCollectionABCs(ABCTestCase):
def test_issue_4920(self):
# MutableSet.pop() method did not work
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySet(MutableSet):
__slots__=['__s']
def __init__(self,items=None):
@ -1615,7 +1615,7 @@ class TestCollectionABCs(ABCTestCase):
def test_issue16373(self):
# Recursion error comparing comparable and noncomparable
# Set instances
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyComparableSet(Set):
def __contains__(self, x):
return False
@ -1644,7 +1644,7 @@ class TestCollectionABCs(ABCTestCase):
def test_issue26915(self):
# Container membership test should check identity first
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomSequence(Sequence):
def __init__(self, seq):
self._seq = seq
@ -1676,7 +1676,7 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_from_iterable(self):
"""Verify _from_iterable overridden to an instance method works."""
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SetUsingInstanceFromIterable(MutableSet):
def __init__(self, values, created_by):
if not created_by:
@ -1733,7 +1733,7 @@ class TestCollectionABCs(ABCTestCase):
def test_Set_interoperability_with_real_sets(self):
# Issue: 8743
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ListSet(Set):
def __init__(self, elements=()):
self.data = []
@ -1902,7 +1902,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(issubclass(sample, Mapping))
self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__',
'__getitem__')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyMapping(Mapping):
def __len__(self):
return 0
@ -1960,7 +1960,7 @@ class TestCollectionABCs(ABCTestCase):
'__getitem__')
def test_Sequence_mixins(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SequenceSubclass(Sequence):
def __init__(self, seq=()):
self.seq = seq
@ -2041,7 +2041,7 @@ class TestCollectionABCs(ABCTestCase):
def test_MutableSequence_mixins(self):
# Test the mixins of MutableSequence by creating a minimal concrete
# class inherited from it.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MutableSequenceSubclass(MutableSequence):
def __init__(self):
self.lst = []
@ -2284,7 +2284,7 @@ class TestCounter(__TestCase):
check(Counter(words))
def test_copy_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyCounter(Counter):
pass
c = MyCounter('slartibartfast')

View File

@ -240,7 +240,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
def test_boolcontext(self):
for i in range(100):
- self.assertTrue(complex(random() + 1e-6, random() + 1e-6))
+ with torch._dynamo.set_fullgraph(False):
+ with torch._dynamo.error_on_graph_break(False):
+ r1 = random()
+ r2 = random()
+ self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
@ -253,7 +253,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
- class EvilExc(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class EvilExc(Exception):
+ pass
@ -273,7 +273,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
- class MyInt:
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyInt:
+ def __int__(self):
+ return 42
@ -299,7 +299,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
- complex is returned"""
- def __complex__(self):
- return None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class complex0(complex):
+ """Test usage of __complex__() when inheriting from 'complex'"""
+ def __complex__(self):

View File

@ -526,7 +526,7 @@ class ComplexTest(__TestCase):
def test_boolcontext(self):
for i in range(100):
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
r1 = random()
r2 = random()
self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
@ -622,7 +622,7 @@ class ComplexTest(__TestCase):
self.assertRaises(TypeError, complex, WithComplex(1), object())
self.assertRaises(TypeError, complex, WithComplex(None), object())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EvilExc(Exception):
pass
@ -652,7 +652,7 @@ class ComplexTest(__TestCase):
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyInt:
def __int__(self):
return 42
@ -661,7 +661,7 @@ class ComplexTest(__TestCase):
self.assertRaises(TypeError, complex, MyInt(), 1.5)
self.assertRaises(TypeError, complex, 1.5, MyInt())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class complex0(complex):
"""Test usage of __complex__() when inheriting from 'complex'"""
def __complex__(self):

View File

@ -71,7 +71,7 @@ index cf651959803..256a824932d 100644
- class DefaultEnter(AbstractContextManager):
- def __exit__(self, *args):
- super().__exit__(*args)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultEnter(AbstractContextManager):
+ def __exit__(self, *args):
+ super().__exit__(*args)
@ -82,7 +82,7 @@ index cf651959803..256a824932d 100644
def test_slots(self):
- class DefaultContextManager(AbstractContextManager):
- __slots__ = ()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultContextManager(AbstractContextManager):
+ __slots__ = ()
@ -97,7 +97,7 @@ index cf651959803..256a824932d 100644
def test_exit_is_abstract(self):
- class MissingExit(AbstractContextManager):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MissingExit(AbstractContextManager):
+ pass
@ -110,7 +110,7 @@ index cf651959803..256a824932d 100644
- return self
- def __exit__(self, exc_type, exc_value, traceback):
- return None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ManagerFromScratch:
+ def __enter__(self):
+ return self
@ -122,7 +122,7 @@ index cf651959803..256a824932d 100644
- class DefaultEnter(AbstractContextManager):
- def __exit__(self, *args):
- super().__exit__(*args)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class DefaultEnter(AbstractContextManager):
+ def __exit__(self, *args):
+ super().__exit__(*args)
@ -131,7 +131,7 @@ index cf651959803..256a824932d 100644
- class NoEnter(ManagerFromScratch):
- __enter__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NoEnter(ManagerFromScratch):
+ __enter__ = None
@ -139,7 +139,7 @@ index cf651959803..256a824932d 100644
- class NoExit(ManagerFromScratch):
- __exit__ = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NoExit(ManagerFromScratch):
+ __exit__ = None
@ -157,7 +157,7 @@ index cf651959803..256a824932d 100644
# Repeat with RuntimeError (which goes through a different code path)
- class RuntimeErrorSubclass(RuntimeError):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class RuntimeErrorSubclass(RuntimeError):
+ pass
@ -169,7 +169,7 @@ index cf651959803..256a824932d 100644
- class StopIterationSubclass(StopIteration):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class StopIterationSubclass(StopIteration):
+ pass
@ -207,7 +207,7 @@ index cf651959803..256a824932d 100644
- class StopIterationSubclass(StopIteration):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class StopIterationSubclass(StopIteration):
+ pass
@ -219,7 +219,7 @@ index cf651959803..256a824932d 100644
def test_nokeepref(self):
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
@ -241,7 +241,7 @@ index cf651959803..256a824932d 100644
- class C:
- def close(self):
- state.append(1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def close(self):
+ state.append(1)
@ -255,7 +255,7 @@ index cf651959803..256a824932d 100644
- class C:
- def close(self):
- state.append(1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def close(self):
+ state.append(1)
@ -271,7 +271,7 @@ index cf651959803..256a824932d 100644
def test_nullcontext(self):
- class C:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ pass
c = C()
@ -307,7 +307,7 @@ index cf651959803..256a824932d 100644
context = mycontext()
- class Test(object):
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Test(object):
- @context
@ -332,7 +332,7 @@ index cf651959803..256a824932d 100644
- pass
- def __exit__(self, *exc):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mycontext(ContextDecorator):
+ def __unter__(self):
+ pass
@ -350,7 +350,7 @@ index cf651959803..256a824932d 100644
- pass
- def __uxit__(self, *exc):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mycontext(ContextDecorator):
+ def __enter__(self):
+ pass
@ -366,7 +366,7 @@ index cf651959803..256a824932d 100644
- class somecontext(object):
- started = False
- exc = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class somecontext(object):
+ started = False
+ exc = None
@ -410,7 +410,7 @@ index cf651959803..256a824932d 100644
- self.fail("Should not be called!")
- def __exit__(self, *exc_details):
- self.check_exc(*exc_details)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ExitCM(object):
+ def __init__(self, check_exc):
+ self.check_exc = check_exc
@ -430,7 +430,7 @@ index cf651959803..256a824932d 100644
- result.append(1)
- def __exit__(self, *exc_details):
- result.append(3)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestCM(object):
+ def __enter__(self):
+ result.append(1)
@ -450,7 +450,7 @@ index cf651959803..256a824932d 100644
- pass
- class LacksExit:
- def __enter__(self):
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnterAndExit:
pass
+ class LacksEnter:
@ -492,7 +492,7 @@ index cf651959803..256a824932d 100644
- def __exit__(self, *exc_details):
- type(self).saved_details = exc_details
- return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class RaiseExc:
+ def __init__(self, exc):
+ self.exc = exc
@ -528,7 +528,7 @@ index cf651959803..256a824932d 100644
- class MyException(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ pass
@ -539,7 +539,7 @@ index cf651959803..256a824932d 100644
def test_instance_bypass(self):
- class Example(object): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Example(object): pass
cm = Example()
cm.__enter__ = object()
@ -550,7 +550,7 @@ index cf651959803..256a824932d 100644
# https://bugs.python.org/issue27122
- class UniqueException(Exception): pass
- class UniqueRuntimeError(RuntimeError): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class UniqueException(Exception): pass
+ class UniqueRuntimeError(RuntimeError): pass

View File

@ -71,7 +71,7 @@ import weakref
class TestAbstractContextManager(__TestCase):
def test_enter(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
@ -80,7 +80,7 @@ class TestAbstractContextManager(__TestCase):
self.assertIs(manager.__enter__(), manager)
def test_slots(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultContextManager(AbstractContextManager):
__slots__ = ()
@ -91,7 +91,7 @@ class TestAbstractContextManager(__TestCase):
DefaultContextManager().var = 42
def test_exit_is_abstract(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MissingExit(AbstractContextManager):
pass
@ -99,7 +99,7 @@ class TestAbstractContextManager(__TestCase):
MissingExit()
def test_structural_subclassing(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ManagerFromScratch:
def __enter__(self):
return self
@ -108,20 +108,20 @@ class TestAbstractContextManager(__TestCase):
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NoEnter(ManagerFromScratch):
__enter__ = None
self.assertFalse(issubclass(NoEnter, AbstractContextManager))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NoExit(ManagerFromScratch):
__exit__ = None
@ -176,7 +176,7 @@ class ContextManagerTestCase(__TestCase):
self.assertEqual(frames[0].line, '1/0')
# Repeat with RuntimeError (which goes through a different code path)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class RuntimeErrorSubclass(RuntimeError):
pass
@ -190,7 +190,7 @@ class ContextManagerTestCase(__TestCase):
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class StopIterationSubclass(StopIteration):
pass
@ -293,7 +293,7 @@ class ContextManagerTestCase(__TestCase):
def woohoo():
yield
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class StopIterationSubclass(StopIteration):
pass
@ -408,7 +408,7 @@ def woohoo():
self.assertEqual(target, (11, 22, 33, 44))
def test_nokeepref(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
@ -472,7 +472,7 @@ class ClosingTestCase(__TestCase):
def test_closing(self):
state = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def close(self):
state.append(1)
@ -484,7 +484,7 @@ class ClosingTestCase(__TestCase):
def test_closing_error(self):
state = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def close(self):
state.append(1)
@ -499,7 +499,7 @@ class ClosingTestCase(__TestCase):
class NullcontextTestCase(__TestCase):
def test_nullcontext(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
pass
c = C()
@ -652,7 +652,7 @@ class TestContextDecorator(__TestCase):
def test_decorating_method(self):
context = mycontext()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Test(object):
@context
@ -681,7 +681,7 @@ class TestContextDecorator(__TestCase):
def test_typo_enter(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mycontext(ContextDecorator):
def __unter__(self):
pass
@ -694,7 +694,7 @@ class TestContextDecorator(__TestCase):
def test_typo_exit(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mycontext(ContextDecorator):
def __enter__(self):
pass
@ -707,7 +707,7 @@ class TestContextDecorator(__TestCase):
def test_contextdecorator_as_mixin(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class somecontext(object):
started = False
exc = None
@ -817,7 +817,7 @@ class _TestBaseExitStack:
self.assertIsNone(exc_type)
self.assertIsNone(exc)
self.assertIsNone(exc_tb)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ExitCM(object):
def __init__(self, check_exc):
self.check_exc = check_exc
@ -843,7 +843,7 @@ class _TestBaseExitStack:
1/0
def test_enter_context(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestCM(object):
def __enter__(self):
result.append(1)
@ -863,7 +863,7 @@ class _TestBaseExitStack:
self.assertEqual(result, [1, 2, 3, 4])
def test_enter_context_errors(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksEnterAndExit:
pass
class LacksEnter:
@ -952,7 +952,7 @@ class _TestBaseExitStack:
def test_exit_exception_chaining_reference(self):
# Sanity check to make sure that ExitStack chaining matches
# actual nested with statements
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class RaiseExc:
def __init__(self, exc):
self.exc = exc
@ -1033,7 +1033,7 @@ class _TestBaseExitStack:
# Ensure ExitStack chaining matches actual nested `with` statements
# regarding explicit __context__ = None.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
pass
@ -1173,7 +1173,7 @@ class _TestBaseExitStack:
stack.callback(int)
def test_instance_bypass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Example(object): pass
cm = Example()
cm.__enter__ = object()
@ -1186,7 +1186,7 @@ class _TestBaseExitStack:
def test_dont_reraise_RuntimeError(self):
# https://bugs.python.org/issue27122
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class UniqueException(Exception): pass
class UniqueRuntimeError(RuntimeError): pass

View File

@ -81,7 +81,7 @@ index bdbe9b81e8f..d55f1dc54c6 100644
- self.default_factory = self._factory
- def _factory(self):
- return []
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class sub(defaultdict):
+ def __init__(self):
+ self.default_factory = self._factory

View File

@ -184,7 +184,7 @@ class TestDefaultDict(__TestCase):
def test_recursive_repr(self):
# Issue2045: stack overflow when default_factory is a bound method
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class sub(defaultdict):
def __init__(self):
self.default_factory = self._factory

View File

@ -73,7 +73,7 @@ index 4729132c5a5..6ecf111c1e3 100644
def test_invalid_keyword_arguments(self):
- class Custom(dict):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Custom(dict):
+ pass
for invalid in {1 : 2}, Custom({1 : 2}):
@ -85,7 +85,7 @@ index 4729132c5a5..6ecf111c1e3 100644
mappingproxy = type(type.__dict__)
- class Dict(dict):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Dict(dict):
+ pass
for cls in [dict, Dict]:
@ -100,7 +100,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- raise Exc()
- def __hash__(self):
- return 24
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadEq(object):
+ def __eq__(self, other):
+ raise Exc()
@ -112,7 +112,7 @@ index 4729132c5a5..6ecf111c1e3 100644
self.assertRaises(KeyError, d.__getitem__, 23)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadHash(object):
@ -143,7 +143,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- return self.d.keys()
- def __getitem__(self, i):
- return self.d[i]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SimpleUserDict:
+ def __init__(self):
+ self.d = {1:1, 2:2, 3:3}
@ -156,7 +156,7 @@ index 4729132c5a5..6ecf111c1e3 100644
self.assertEqual(d, {1:1, 2:2, 3:3})
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
d.clear()
@ -164,7 +164,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def keys(self):
- raise Exc
+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ raise Exc
@ -185,7 +185,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- return BogonIter()
- def __getitem__(self, key):
- return key
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ class BogonIter:
@ -219,7 +219,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- return BogonIter()
- def __getitem__(self, key):
- raise Exc
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class FailingUserDict:
+ def keys(self):
+ class BogonIter:
@ -244,7 +244,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __next__(self):
- raise Exc()
+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class badseq(object):
+ def __iter__(self):
+ return self
@ -264,7 +264,7 @@ index 4729132c5a5..6ecf111c1e3 100644
self.assertEqual(d.fromkeys(g()), {1:None})
self.assertRaises(TypeError, {}.fromkeys, 3)
- class dictlike(dict): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class dictlike(dict): pass
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
@ -273,7 +273,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class mydict(dict):
- def __new__(cls):
- return collections.UserDict()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class mydict(dict):
+ def __new__(cls):
+ return collections.UserDict()
@ -283,7 +283,7 @@ index 4729132c5a5..6ecf111c1e3 100644
self.assertRaises(TypeError, dict.fromkeys)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class baddict1(dict):
@ -300,7 +300,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- return self
- def __next__(self):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadSeq(object):
+ def __iter__(self):
+ return self
@ -312,7 +312,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class baddict2(dict):
- def __setitem__(self, key, value):
- raise Exc()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class baddict2(dict):
+ def __setitem__(self, key, value):
+ raise Exc()
@ -326,7 +326,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class baddict3(dict):
- def __new__(cls):
- return d
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class baddict3(dict):
+ def __new__(cls):
+ return d
@ -339,7 +339,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class baddict4(dict):
- def __init__(self):
- dict.__init__(self, d)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class baddict4(dict):
+ def __init__(self):
+ dict.__init__(self, d)
@ -352,7 +352,7 @@ index 4729132c5a5..6ecf111c1e3 100644
def test_copy_maintains_tracking(self):
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
@ -371,7 +371,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- raise Exc()
- else:
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
+
+ class BadHash(object):
@ -398,7 +398,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __eq__(self, other):
- self.eq_count += 1
- return id(self) == id(other)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Hashed(object):
+ def __init__(self):
+ self.hash_count = 0
@ -426,7 +426,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __eq__(self, other):
- self.eq_count += 1
- return id(self) == id(other)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Hashed(object):
+ def __init__(self):
+ self.hash_count = 0
@ -454,7 +454,7 @@ index 4729132c5a5..6ecf111c1e3 100644
self.assertRaises(TypeError, d.pop)
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadHash(object):
@ -480,7 +480,7 @@ index 4729132c5a5..6ecf111c1e3 100644
# changing dict during a lookup (issue #14417)
- class NastyKey:
- mutate_dict = None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NastyKey:
+ mutate_dict = None
@ -516,7 +516,7 @@ index 4729132c5a5..6ecf111c1e3 100644
self.assertEqual(repr(d), '{1: {...}}')
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadRepr(object):
@ -533,7 +533,7 @@ index 4729132c5a5..6ecf111c1e3 100644
self.assertEqual({1: 2}, {1: 2})
- class Exc(Exception): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
- class BadCmp(object):
@ -556,7 +556,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class C:
- def __eq__(self, other):
- raise RuntimeError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __eq__(self, other):
+ raise RuntimeError
@ -570,7 +570,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class D(dict):
- def __missing__(self, key):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class D(dict):
+ def __missing__(self, key):
+ return 42
@ -584,7 +584,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class E(dict):
- def __missing__(self, key):
- raise RuntimeError(key)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class E(dict):
+ def __missing__(self, key):
+ raise RuntimeError(key)
@ -597,7 +597,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __init__(self):
- # An instance variable __missing__ should have no effect
- self.__missing__ = lambda key: None
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class F(dict):
+ def __init__(self):
+ # An instance variable __missing__ should have no effect
@ -609,7 +609,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class G(dict):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class G(dict):
+ pass
g = G()
@ -621,7 +621,7 @@ index 4729132c5a5..6ecf111c1e3 100644
# Dictionary lookups should fail if __eq__() raises an exception.
- class CustomException(Exception):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomException(Exception):
+ pass
@ -654,7 +654,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- if resizing:
- d.clear()
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X(object):
+ def __hash__(self):
+ return 5
@ -671,7 +671,7 @@ index 4729132c5a5..6ecf111c1e3 100644
# dictview objects.
- class C(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ pass
views = (dict.items, dict.values, dict.keys)
@ -684,7 +684,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class MyObject(object):
- pass
+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyObject(object):
+ pass
x, y, z, w, o = 1.5, "a", (1, object()), [], MyObject()
@ -709,7 +709,7 @@ index 4729132c5a5..6ecf111c1e3 100644
def make_shared_key_dict(self, n):
- class C:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ pass
@ -725,7 +725,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- self.a, self.b, self.c = 1, 2, 3
- else:
- self.c, self.b, self.a = 1, 2, 3
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __init__(self, order):
+ if order:
@ -741,7 +741,7 @@ index 4729132c5a5..6ecf111c1e3 100644
"""split table must be correctly resized and converted to generic combined table"""
- class C:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ pass
@ -754,14 +754,14 @@ index 4729132c5a5..6ecf111c1e3 100644
- class Foo:
- def __init__(self, msg):
- self.msg = msg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo:
+ def __init__(self, msg):
+ self.msg = msg
f = Foo('123')
- class _str(str):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class _str(str):
+ pass
self.assertEqual(f.msg, getattr(f, _str('msg')))
@ -769,7 +769,7 @@ index 4729132c5a5..6ecf111c1e3 100644
def test_object_set_item_single_instance_non_str_key(self):
- class Foo: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo: pass
f = Foo()
f.__dict__[1] = 1
@ -781,7 +781,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class Mutating:
- def __del__(self):
- mutate(d)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Mutating:
+ def __del__(self):
+ mutate(d)
@ -795,7 +795,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class X:
- def __hash__(self):
- return 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __hash__(self):
+ return 0
@ -816,7 +816,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class X():
- def __del__(self):
- dict_b.clear()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X():
+ def __del__(self):
+ dict_b.clear()
@ -842,7 +842,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __eq__(self, other):
- dict_d.clear()
- return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Y:
+ def __eq__(self, other):
+ dict_d.clear()
@ -857,7 +857,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class X(int):
- def __hash__(self):
- return 13
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X(int):
+ def __hash__(self):
+ return 13
@ -880,7 +880,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class X(int):
- def __hash__(self):
- return 13
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X(int):
+ def __hash__(self):
+ return 13
@ -904,7 +904,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __eq__(self, other):
- d.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __eq__(self, other):
+ d.clear()
@ -919,7 +919,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __eq__(self, other):
- d.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class S(str):
+ def __eq__(self, other):
+ d.clear()
@ -938,7 +938,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __hash__(self):
- pair[:] = []
- return 13
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __hash__(self):
+ pair[:] = []
@ -951,7 +951,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class X(int):
- def __del__(self):
- d.clear()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X(int):
+ def __del__(self):
+ d.clear()
@ -966,7 +966,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- def __init__(self, x, y):
- if x: self.x = x
- if y: self.y = y
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def __init__(self, x, y):
+ if x: self.x = x
@ -980,7 +980,7 @@ index 4729132c5a5..6ecf111c1e3 100644
# dict subclass doesn't override __iter__
- class CustomDict(dict):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomDict(dict):
+ pass
@ -992,7 +992,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class CustomReversedDict(dict):
- def keys(self):
- return reversed(list(dict.keys(self)))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomReversedDict(dict):
+ def keys(self):
+ return reversed(list(dict.keys(self)))
@ -1014,7 +1014,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class EvilAttr:
- def __init__(self, d):
- self.d = d
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class EvilAttr:
+ def __init__(self, d):
+ self.d = d
@ -1041,7 +1041,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- class StrSub(str):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class StrSub(str):
+ pass
@ -1057,7 +1057,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- eq_count += 1
- return True
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key3:
+ def __hash__(self):
+ return hash('key3')
@ -1090,7 +1090,7 @@ index 4729132c5a5..6ecf111c1e3 100644
- raise Exc
- def __hash__(self):
- return 7
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Exc(Exception): pass
+ class BadEq:
+ def __eq__(self, other):

View File

@ -71,7 +71,7 @@ from test.support import import_helper, get_c_recursion_limit
class DictTest(__TestCase):
def test_invalid_keyword_arguments(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Custom(dict):
pass
for invalid in {1 : 2}, Custom({1 : 2}):
@ -166,7 +166,7 @@ class DictTest(__TestCase):
def test_views_mapping(self):
mappingproxy = type(type.__dict__)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Dict(dict):
pass
for cls in [dict, Dict]:
@ -216,7 +216,7 @@ class DictTest(__TestCase):
self.assertRaises(TypeError, d.__getitem__)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadEq(object):
def __eq__(self, other):
raise Exc()
@ -227,7 +227,7 @@ class DictTest(__TestCase):
d[BadEq()] = 42
self.assertRaises(KeyError, d.__getitem__, 23)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -262,7 +262,7 @@ class DictTest(__TestCase):
self.assertRaises((TypeError, AttributeError), d.update, None)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SimpleUserDict:
def __init__(self):
self.d = {1:1, 2:2, 3:3}
@ -274,18 +274,18 @@ class DictTest(__TestCase):
d.update(SimpleUserDict())
self.assertEqual(d, {1:1, 2:2, 3:3})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
d.clear()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
raise Exc
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -303,7 +303,7 @@ class DictTest(__TestCase):
return key
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class FailingUserDict:
def keys(self):
class BogonIter:
@ -323,7 +323,7 @@ class DictTest(__TestCase):
self.assertRaises(Exc, d.update, FailingUserDict())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class badseq(object):
def __iter__(self):
return self
@ -346,13 +346,13 @@ class DictTest(__TestCase):
yield 1
self.assertEqual(d.fromkeys(g()), {1:None})
self.assertRaises(TypeError, {}.fromkeys, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class dictlike(dict): pass
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
self.assertIsInstance(dictlike.fromkeys('a'), dictlike)
self.assertIsInstance(dictlike().fromkeys('a'), dictlike)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class mydict(dict):
def __new__(cls):
return collections.UserDict()
@ -361,7 +361,7 @@ class DictTest(__TestCase):
self.assertIsInstance(ud, collections.UserDict)
self.assertRaises(TypeError, dict.fromkeys)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class baddict1(dict):
@ -370,7 +370,7 @@ class DictTest(__TestCase):
self.assertRaises(Exc, baddict1.fromkeys, [1])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadSeq(object):
def __iter__(self):
return self
@ -379,7 +379,7 @@ class DictTest(__TestCase):
self.assertRaises(Exc, dict.fromkeys, BadSeq())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict2(dict):
def __setitem__(self, key, value):
raise Exc()
@ -398,7 +398,7 @@ class DictTest(__TestCase):
self.assertEqual(dict.fromkeys(d, 0), res)
# test fast path when object's constructor returns large non-empty dict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict3(dict):
def __new__(cls):
return d
@ -408,7 +408,7 @@ class DictTest(__TestCase):
self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res)
# test slow path when object is a proper subclass of dict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class baddict4(dict):
def __init__(self):
dict.__init__(self, d)
@ -447,7 +447,7 @@ class DictTest(__TestCase):
self.assertEqual(len(d2), len(d) + 1)
def test_copy_maintains_tracking(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
@ -495,7 +495,7 @@ class DictTest(__TestCase):
self.assertRaises(TypeError, d.setdefault)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -513,7 +513,7 @@ class DictTest(__TestCase):
def test_setdefault_atomic(self):
# Issue #13521: setdefault() calls __hash__ and __eq__ only once.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Hashed(object):
def __init__(self):
self.hash_count = 0
@ -533,7 +533,7 @@ class DictTest(__TestCase):
self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1)
def test_setitem_atomic_at_resize(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Hashed(object):
def __init__(self):
self.hash_count = 0
@ -599,7 +599,7 @@ class DictTest(__TestCase):
self.assertRaises(TypeError, d.pop)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadHash(object):
@ -652,7 +652,7 @@ class DictTest(__TestCase):
def test_mutating_lookup(self):
# changing dict during a lookup (issue #14417)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NastyKey:
mutate_dict = None
@ -686,7 +686,7 @@ class DictTest(__TestCase):
d[1] = d
self.assertEqual(repr(d), '{1: {...}}')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadRepr(object):
@ -706,7 +706,7 @@ class DictTest(__TestCase):
self.assertEqual({}, {})
self.assertEqual({1: 2}, {1: 2})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadCmp(object):
@ -770,7 +770,7 @@ class DictTest(__TestCase):
self.assertFalse(larger == larger3)
def test_errors_in_view_containment_check(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __eq__(self, other):
raise RuntimeError
@ -853,7 +853,7 @@ class DictTest(__TestCase):
# (E) subclass defines __missing__ method raising RuntimeError
# (F) subclass sets __missing__ instance variable (no effect)
# (G) subclass doesn't define __missing__ at all
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class D(dict):
def __missing__(self, key):
return 42
@ -864,7 +864,7 @@ class DictTest(__TestCase):
self.assertNotIn(2, d.keys())
self.assertEqual(d[2], 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class E(dict):
def __missing__(self, key):
raise RuntimeError(key)
@ -873,7 +873,7 @@ class DictTest(__TestCase):
e[42]
self.assertEqual(c.exception.args, (42,))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F(dict):
def __init__(self):
# An instance variable __missing__ should have no effect
@ -883,7 +883,7 @@ class DictTest(__TestCase):
f[42]
self.assertEqual(c.exception.args, (42,))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class G(dict):
pass
g = G()
@ -900,7 +900,7 @@ class DictTest(__TestCase):
def test_bad_key(self):
# Dictionary lookups should fail if __eq__() raises an exception.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomException(Exception):
pass
@ -947,7 +947,7 @@ class DictTest(__TestCase):
# Another dict resizing bug (SF bug #1456209).
# This caused Segmentation faults or Illegal instructions.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(object):
def __hash__(self):
return 5
@ -977,7 +977,7 @@ class DictTest(__TestCase):
def test_container_iterator(self):
# Bug #3680: tp_traverse was not implemented for dictiter and
# dictview objects.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
pass
views = (dict.items, dict.values, dict.keys)
@ -1033,7 +1033,7 @@ class DictTest(__TestCase):
def test_track_dynamic(self):
# Test GC-optimization of dynamically-created dicts
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyObject(object):
pass
x, y, z, w, o = 1.5, "a", (1, object()), [], MyObject()
@ -1103,7 +1103,7 @@ class DictTest(__TestCase):
self._tracked(MyDict())
def make_shared_key_dict(self, n):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
pass
@ -1194,7 +1194,7 @@ class DictTest(__TestCase):
@support.cpython_only
def test_splittable_update(self):
"""dict.update(other) must preserve order in other."""
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __init__(self, order):
if order:
@ -1212,7 +1212,7 @@ class DictTest(__TestCase):
@support.cpython_only
def test_splittable_to_generic_combinedtable(self):
"""split table must be correctly resized and converted to generic combined table"""
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
pass
@ -1336,19 +1336,19 @@ class DictTest(__TestCase):
self.assertEqual(sorted(values), sorted(data.values()))
def test_instance_dict_getattr_str_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo:
def __init__(self, msg):
self.msg = msg
f = Foo('123')
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class _str(str):
pass
self.assertEqual(f.msg, getattr(f, _str('msg')))
self.assertEqual(f.msg, f.__dict__[_str('msg')])
def test_object_set_item_single_instance_non_str_key(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo: pass
f = Foo()
f.__dict__[1] = 1
@ -1359,7 +1359,7 @@ class DictTest(__TestCase):
# This object will trigger mutation of the dict when replaced
# by another value. Note this relies on refcounting: the test
# won't achieve its purpose on fully-GCed Python implementations.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Mutating:
def __del__(self):
mutate(d)
@ -1385,7 +1385,7 @@ class DictTest(__TestCase):
self.check_reentrant_insertion(mutate)
def test_merge_and_mutate(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __hash__(self):
return 0
@ -1408,7 +1408,7 @@ class DictTest(__TestCase):
def test_equal_operator_modifying_operand(self):
# test fix for seg fault reported in bpo-27945 part 3.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X():
def __del__(self):
dict_b.clear()
@ -1425,7 +1425,7 @@ class DictTest(__TestCase):
self.assertTrue(dict_a == dict_b)
# test fix for seg fault reported in bpo-38588 part 1.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Y:
def __eq__(self, other):
dict_d.clear()
@ -1437,7 +1437,7 @@ class DictTest(__TestCase):
def test_fromkeys_operator_modifying_dict_operand(self):
# test fix for seg fault reported in issue 27945 part 4a.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(int):
def __hash__(self):
return 13
@ -1456,7 +1456,7 @@ class DictTest(__TestCase):
def test_fromkeys_operator_modifying_set_operand(self):
# test fix for seg fault reported in issue 27945 part 4b.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(int):
def __hash__(self):
return 13
@ -1474,7 +1474,7 @@ class DictTest(__TestCase):
pass
def test_dictitems_contains_use_after_free(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __eq__(self, other):
d.clear()
@ -1485,7 +1485,7 @@ class DictTest(__TestCase):
def test_dict_contain_use_after_free(self):
# bpo-40489
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class S(str):
def __eq__(self, other):
d.clear()
@ -1498,7 +1498,7 @@ class DictTest(__TestCase):
self.assertFalse('test' in d)
def test_init_use_after_free(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __hash__(self):
pair[:] = []
@ -1508,7 +1508,7 @@ class DictTest(__TestCase):
dict([pair])
def test_oob_indexing_dictiter_iternextitem(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(int):
def __del__(self):
d.clear()
@ -1545,7 +1545,7 @@ class DictTest(__TestCase):
self.assertEqual(list(reversed(dict().keys())), [])
def test_reverse_iterator_for_shared_shared_dicts(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def __init__(self, x, y):
if x: self.x = x
@ -1565,7 +1565,7 @@ class DictTest(__TestCase):
self.assertEqual(list(copy.items()), expected)
# dict subclass doesn't override __iter__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomDict(dict):
pass
@ -1574,7 +1574,7 @@ class DictTest(__TestCase):
d = CustomDict(pairs)
self.assertEqual(pairs, list(dict(d).items()))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomReversedDict(dict):
def keys(self):
return reversed(list(dict.keys(self)))
@ -1607,7 +1607,7 @@ class DictTest(__TestCase):
self.assertTrue(gc.is_tracked(next(it)))
def test_store_evilattr(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EvilAttr:
def __init__(self, d):
self.d = d
@ -1630,13 +1630,13 @@ class DictTest(__TestCase):
# `str` keys. Make sure the unoptimized path is used when a non-`str`
# key appears.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class StrSub(str):
pass
eq_count = 0
# This class compares equal to the string 'key3'
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key3:
def __hash__(self):
return hash('key3')
@ -1746,7 +1746,7 @@ class CAPITest(__TestCase):
# key does not exist
self.assertRaises(KeyError, dict_getitem_knownhash, {}, 1, hash(1))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Exc(Exception): pass
class BadEq:
def __eq__(self, other):

View File

@ -166,7 +166,7 @@ index 97f951f1299..da82bd190c3 100644
- class CustomStr(str): pass
- class CustomBytes(bytes): pass
- class CustomByteArray(bytearray): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomStr(str): pass
+ class CustomBytes(bytes): pass
+ class CustomByteArray(bytearray): pass
@ -180,7 +180,7 @@ index 97f951f1299..da82bd190c3 100644
- class Foo1(object):
- def __float__(self):
- return 42.
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo1(object):
+ def __float__(self):
+ return 42.
@ -231,7 +231,7 @@ index 97f951f1299..da82bd190c3 100644
- class Foo5:
- def __float__(self):
- return ""
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo5:
+ def __float__(self):
+ return ""
@ -241,7 +241,7 @@ index 97f951f1299..da82bd190c3 100644
- class F:
- def __float__(self):
- return OtherFloatSubclass(42.)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Issue #24731
+ class F:
+ def __float__(self):
@ -258,7 +258,7 @@ index 97f951f1299..da82bd190c3 100644
- self.value = value
- def __index__(self):
- return self.value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyIndex:
+ def __init__(self, value):
+ self.value = value
@ -271,7 +271,7 @@ index 97f951f1299..da82bd190c3 100644
- class MyInt:
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyInt:
+ def __int__(self):
+ return 42
@ -284,7 +284,7 @@ index 97f951f1299..da82bd190c3 100644
def test_keywords_in_subclass(self):
- class subclass(float):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(float):
+ pass
u = subclass(2.5)
@ -296,7 +296,7 @@ index 97f951f1299..da82bd190c3 100644
- class subclass_with_init(float):
- def __init__(self, arg, newarg=None):
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(float):
+ def __init__(self, arg, newarg=None):
+ self.newarg = newarg
@ -310,7 +310,7 @@ index 97f951f1299..da82bd190c3 100644
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(float):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -328,7 +328,7 @@ index 97f951f1299..da82bd190c3 100644
- return 42
- class F(float, H):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class H:
+ def __hash__(self):
+ return 42
@ -455,7 +455,7 @@ index 97f951f1299..da82bd190c3 100644
- class F(float):
- def __new__(cls, value):
- return float.__new__(cls, value + 1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class F(float):
+ def __new__(cls, value):
+ return float.__new__(cls, value + 1)
@ -467,7 +467,7 @@ index 97f951f1299..da82bd190c3 100644
- class F2(float):
- def __init__(self, value):
- self.foo = 'bar'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class F2(float):
+ def __init__(self, value):
+ self.foo = 'bar'

View File

@ -222,7 +222,7 @@ class GeneralFloatCases(__TestCase):
def test_non_numeric_input_types(self):
# Test possible non-numeric types for the argument x, including
# subclasses of the explicitly documented accepted types.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomStr(str): pass
class CustomBytes(bytes): pass
class CustomByteArray(bytearray): pass
@ -312,7 +312,7 @@ class GeneralFloatCases(__TestCase):
def test_floatconversion(self):
# Make sure that calls to __float__() work properly
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo1(object):
def __float__(self):
return 42.
@ -345,13 +345,13 @@ class GeneralFloatCases(__TestCase):
self.assertRaises(TypeError, float, Foo4(42))
self.assertEqual(float(FooStr('8')), 9.)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo5:
def __float__(self):
return ""
self.assertRaises(TypeError, time.sleep, Foo5())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Issue #24731
class F:
def __float__(self):
@ -365,7 +365,7 @@ class GeneralFloatCases(__TestCase):
with self.assertWarns(DeprecationWarning):
self.assertIs(type(FloatSubclass(F())), FloatSubclass)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyIndex:
def __init__(self, value):
self.value = value
@ -375,7 +375,7 @@ class GeneralFloatCases(__TestCase):
self.assertEqual(float(MyIndex(42)), 42.0)
self.assertRaises(OverflowError, float, MyIndex(2**2000))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyInt:
def __int__(self):
return 42
@ -387,7 +387,7 @@ class GeneralFloatCases(__TestCase):
float(x='3.14')
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(float):
pass
u = subclass(2.5)
@ -396,7 +396,7 @@ class GeneralFloatCases(__TestCase):
with self.assertRaises(TypeError):
subclass(x=0)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(float):
def __init__(self, arg, newarg=None):
self.newarg = newarg
@ -405,7 +405,7 @@ class GeneralFloatCases(__TestCase):
self.assertEqual(float(u), 2.5)
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(float):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -746,7 +746,7 @@ class GeneralFloatCases(__TestCase):
def test_hash_nan(self):
value = float('nan')
self.assertEqual(hash(value), object.__hash__(value))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class H:
def __hash__(self):
return 42
@ -1664,7 +1664,7 @@ class HexFloatTestCase(__TestCase):
self.identical(x, fromHex(toHex(x)))
def test_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F(float):
def __new__(cls, value):
return float.__new__(cls, value + 1)
@ -1673,7 +1673,7 @@ class HexFloatTestCase(__TestCase):
self.assertIs(type(f), F)
self.assertEqual(f, 2.5)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F2(float):
def __init__(self, value):
self.foo = 'bar'

View File

@ -165,8 +165,8 @@ index 48825f46911..731680d82a0 100644
- self.value = value
- def __index__(self):
- return self.value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyIndexable(object):
+ def __init__(self, value):
+ self.value = value
@ -183,7 +183,7 @@ index 48825f46911..731680d82a0 100644
- class CustomBytes(bytes): pass
- class CustomByteArray(bytearray): pass
+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomStr(str): pass
+ class CustomBytes(bytes): pass
+ class CustomByteArray(bytearray): pass
@ -196,14 +196,14 @@ index 48825f46911..731680d82a0 100644
# Test __int__()
- class ClassicMissingMethods:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ClassicMissingMethods:
+ pass
self.assertRaises(TypeError, int, ClassicMissingMethods())
- class MissingMethods(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MissingMethods(object):
+ pass
self.assertRaises(TypeError, int, MissingMethods())
@ -211,7 +211,7 @@ index 48825f46911..731680d82a0 100644
- class Foo0:
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Foo0:
+ def __int__(self):
+ return 42
@ -220,7 +220,7 @@ index 48825f46911..731680d82a0 100644
- class Classic:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Classic:
+ pass
for base in (object, Classic):
@ -229,7 +229,7 @@ index 48825f46911..731680d82a0 100644
- return 42
- def __trunc__(self):
- return -12
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IntOverridesTrunc(base):
+ def __int__(self):
+ return 42
@ -240,7 +240,7 @@ index 48825f46911..731680d82a0 100644
- class JustTrunc(base):
- def __trunc__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class JustTrunc(base):
+ def __trunc__(self):
+ return 42
@ -250,7 +250,7 @@ index 48825f46911..731680d82a0 100644
- class ExceptionalTrunc(base):
- def __trunc__(self):
- 1 / 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ExceptionalTrunc(base):
+ def __trunc__(self):
+ 1 / 0
@ -266,7 +266,7 @@ index 48825f46911..731680d82a0 100644
- class TruncReturnsNonInt(base):
- def __trunc__(self):
- return Index()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Index(trunc_result_base):
+ def __index__(self):
+ return 42
@ -280,7 +280,7 @@ index 48825f46911..731680d82a0 100644
- class Intable(trunc_result_base):
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Intable(trunc_result_base):
+ def __int__(self):
+ return 42
@ -298,7 +298,7 @@ index 48825f46911..731680d82a0 100644
- def __trunc__(self):
- # Check that we avoid infinite recursion.
- return NonIntegral()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NonIntegral(trunc_result_base):
+ def __trunc__(self):
+ # Check that we avoid infinite recursion.
@ -321,7 +321,7 @@ index 48825f46911..731680d82a0 100644
- class BadInt(trunc_result_base):
- def __int__(self):
- return 42.0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # Regression test for bugs.python.org/issue16060.
+ class BadInt(trunc_result_base):
+ def __int__(self):
@ -342,7 +342,7 @@ index 48825f46911..731680d82a0 100644
- class MyIndex(int):
- def __index__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyIndex(int):
+ def __index__(self):
+ return 42
@ -363,7 +363,7 @@ index 48825f46911..731680d82a0 100644
- class MyInt(int):
- def __int__(self):
- return 42
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyInt(int):
+ def __int__(self):
+ return 42
@ -384,7 +384,7 @@ index 48825f46911..731680d82a0 100644
- class BadIndex:
- def __index__(self):
- return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadIndex:
+ def __index__(self):
+ return True

View File

@ -436,8 +436,8 @@ class IntTestCases(__TestCase):
int('0', 5.0)
def test_int_base_indexable(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
with torch._dynamo.error_on_graph_break(False):
class MyIndexable(object):
def __init__(self, value):
self.value = value
@ -458,7 +458,7 @@ class IntTestCases(__TestCase):
# Test possible non-numeric types for the argument x, including
# subclasses of the explicitly documented accepted types.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomStr(str): pass
class CustomBytes(bytes): pass
class CustomByteArray(bytearray): pass
@ -503,28 +503,28 @@ class IntTestCases(__TestCase):
def test_intconversion(self):
# Test __int__()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ClassicMissingMethods:
pass
self.assertRaises(TypeError, int, ClassicMissingMethods())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MissingMethods(object):
pass
self.assertRaises(TypeError, int, MissingMethods())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Foo0:
def __int__(self):
return 42
self.assertEqual(int(Foo0()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Classic:
pass
for base in (object, Classic):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IntOverridesTrunc(base):
def __int__(self):
return 42
@ -532,14 +532,14 @@ class IntTestCases(__TestCase):
return -12
self.assertEqual(int(IntOverridesTrunc()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class JustTrunc(base):
def __trunc__(self):
return 42
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(JustTrunc()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ExceptionalTrunc(base):
def __trunc__(self):
1 / 0
@ -548,7 +548,7 @@ class IntTestCases(__TestCase):
int(ExceptionalTrunc())
for trunc_result_base in (object, Classic):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Index(trunc_result_base):
def __index__(self):
return 42
@ -559,7 +559,7 @@ class IntTestCases(__TestCase):
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(TruncReturnsNonInt()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Intable(trunc_result_base):
def __int__(self):
return 42
@ -570,7 +570,7 @@ class IntTestCases(__TestCase):
with self.assertWarns(DeprecationWarning):
self.assertEqual(int(TruncReturnsNonInt()), 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NonIntegral(trunc_result_base):
def __trunc__(self):
# Check that we avoid infinite recursion.
@ -590,7 +590,7 @@ class IntTestCases(__TestCase):
self.fail("Failed to raise TypeError with %s" %
((base, trunc_result_base),))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# Regression test for bugs.python.org/issue16060.
class BadInt(trunc_result_base):
def __int__(self):
@ -605,7 +605,7 @@ class IntTestCases(__TestCase):
int(TruncReturnsBadInt())
def test_int_subclass_with_index(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyIndex(int):
def __index__(self):
return 42
@ -621,7 +621,7 @@ class IntTestCases(__TestCase):
self.assertEqual(int(BadIndex()), 0)
def test_int_subclass_with_int(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyInt(int):
def __int__(self):
return 42
@ -639,7 +639,7 @@ class IntTestCases(__TestCase):
self.assertRaises(TypeError, int, my_int)
def test_int_returns_int_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadIndex:
def __index__(self):
return True

View File

@ -103,7 +103,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- # the pointers after this call
- list(self.iterator)
- return other == self.name
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomStr:
+ def __init__(self, name, iterator):
+ self.name = name
@ -127,7 +127,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- class IterClass(object):
- def __iter__(self):
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IterClass(object):
+ def __iter__(self):
+ return self
@ -143,7 +143,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- if i == 10:
- raise RuntimeError
- return SequenceClass.__getitem__(self, i)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySequenceClass(SequenceClass):
+ def __getitem__(self, i):
+ if i == 10:
@ -161,7 +161,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- if i == 10:
- raise StopIteration
- return SequenceClass.__getitem__(self, i)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MySequenceClass(SequenceClass):
+ def __getitem__(self, i):
+ if i == 10:
@ -179,7 +179,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- self.truth = truth
- def __bool__(self):
- return self.truth
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Boolean:
+ def __init__(self, truth):
+ self.truth = truth
@ -206,7 +206,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- else:
- raise StopIteration
- return SeqIter(self.vals)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Seq:
+ def __init__(self, *args):
+ self.vals = args
@ -243,7 +243,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- class IntsFrom:
- def __init__(self, start):
- self.i = start
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IntsFrom:
+ def __init__(self, start):
+ self.i = start
@ -273,7 +273,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- if i >= 5:
- raise IndexError
- return i
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class NoGuessLen5:
+ def __getitem__(self, i):
+ if i >= 5:
@ -304,7 +304,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- def __init__(self, seq):
- self.it = iter(seq)
- self.i = 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class OhPhooey:
+ def __init__(self, seq):
+ self.it = iter(seq)
@ -339,7 +339,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- self.start = start
- self.finish = finish
- self.i = self.start
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Iterator:
+ def __init__(self, start, finish):
+ self.start = start
@ -393,7 +393,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- cls = self.__class__
- assert cls.count > 0
- cls.count -= 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ count = 0
+ def __new__(cls):
@ -416,7 +416,7 @@ index 1b9f3cf7624..6560c7423a6 100644
- def __next__(self):
- del BadIterator.__next__
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadIterator(object):
+ def __iter__(self):
+ return self

View File

@ -314,7 +314,7 @@ class TestCase(__TestCase):
def run(builtin_name, item, sentinel=None):
it = iter(item) if sentinel is None else iter(item, sentinel)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class CustomStr:
def __init__(self, name, iterator):
self.name = name
@ -377,7 +377,7 @@ class TestCase(__TestCase):
# Test a new_style class with __iter__ but no next() method
def test_new_style_iter_class(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IterClass(object):
def __iter__(self):
return self
@ -449,7 +449,7 @@ class TestCase(__TestCase):
# Test exception propagation through sequence iterator
def test_exception_sequence(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySequenceClass(SequenceClass):
def __getitem__(self, i):
if i == 10:
@ -466,7 +466,7 @@ class TestCase(__TestCase):
# Test for StopIteration from __getitem__
def test_stop_sequence(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MySequenceClass(SequenceClass):
def __getitem__(self, i):
if i == 10:
@ -598,7 +598,7 @@ class TestCase(__TestCase):
self.assertRaises(TypeError, filter, None, list)
self.assertRaises(TypeError, filter, None, 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Boolean:
def __init__(self, truth):
self.truth = truth
@ -607,7 +607,7 @@ class TestCase(__TestCase):
bTrue = Boolean(True)
bFalse = Boolean(False)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Seq:
def __init__(self, *args):
self.vals = args
@ -713,7 +713,7 @@ class TestCase(__TestCase):
self.assertEqual(list(d.items()), list(zip(d, d.values())))
# Generate all ints starting at constructor arg.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IntsFrom:
def __init__(self, start):
self.i = start
@ -747,7 +747,7 @@ class TestCase(__TestCase):
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
# Classes that lie about their lengths.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class NoGuessLen5:
def __getitem__(self, i):
if i >= 5:
@ -780,7 +780,7 @@ class TestCase(__TestCase):
# This class inserts a Unicode object into its argument's natural
# iteration, in the 3rd position.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class OhPhooey:
def __init__(self, seq):
self.it = iter(seq)
@ -958,7 +958,7 @@ class TestCase(__TestCase):
f.writelines({})
# Try a big chunk too.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Iterator:
def __init__(self, start, finish):
self.start = start
@ -1054,7 +1054,7 @@ class TestCase(__TestCase):
@cpython_only
def test_ref_counting_behavior(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
count = 0
def __new__(cls):
@ -1154,7 +1154,7 @@ class TestCase(__TestCase):
def test_3720(self):
# Avoid a crash, when an iterator deletes its next() method.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadIterator(object):
def __iter__(self):
return self

View File

@ -224,7 +224,7 @@ index 7d5ba727389..ff514815da2 100644
- self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- self.pickletest(proto, filterfalse(isEven, range(6)))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ self.pickletest(proto, filterfalse(isEven, range(6)))

View File

@ -1056,7 +1056,7 @@ class TestBasicOps(__TestCase):
self.assertRaises(TypeError, filterfalse, lambda x:x)
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
self.assertRaises(TypeError, filterfalse, isEven, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, filterfalse(isEven, range(6)))
@ -1358,7 +1358,7 @@ class TestBasicOps(__TestCase):
argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
set('abcdefg'), range(11), tuple(range(13))]
for i in range(100):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
args = [choice(argtypes) for j in range(randrange(5))]
expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len)

View File

@ -77,7 +77,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
def test_keywords_in_subclass(self):
- class subclass(list):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(list):
+ pass
u = subclass([1, 2])
@ -90,7 +90,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __init__(self, seq, newarg=None):
- super().__init__(seq)
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(list):
+ def __init__(self, seq, newarg=None):
+ super().__init__(seq)
@ -105,7 +105,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- self = super().__new__(cls, seq)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(list):
+ def __new__(cls, seq, newarg=None):
+ self = super().__new__(cls, seq)
@ -126,7 +126,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- except IndexError:
- pass
- return 'obj'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Obj:
+ @staticmethod
+ def __repr__():
@ -143,7 +143,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
# optimization causes failures in code that relies on distinct
# function addresses.
- class L(list): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class L(list): pass
with self.assertRaises(TypeError):
(3,) + L([1,2])
@ -164,7 +164,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __eq__(self, other):
- list3.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __eq__(self,other) :
+ list2.clear()
@ -191,7 +191,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __lt__(self, other):
- other.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # See gh-120298
+ class evil:
+ def __lt__(self, other):
@ -210,7 +210,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __iter__(self):
- yield from self.lst
- self.lst.clear()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ # See gh-120384
+ class evil:
+ def __init__(self, lst):
@ -229,7 +229,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __eq__(self, other):
- lst.clear()
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __eq__(self, other):
+ lst.clear()
@ -243,7 +243,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
- def __eq__(self, other):
- str(other)
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class L(list):
+ def __eq__(self, other):
+ str(other)

View File

@ -101,7 +101,7 @@ class ListTest(list_tests.CommonTest):
list(sequence=[])
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(list):
pass
u = subclass([1, 2])
@ -110,7 +110,7 @@ class ListTest(list_tests.CommonTest):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(list):
def __init__(self, seq, newarg=None):
super().__init__(seq)
@ -120,7 +120,7 @@ class ListTest(list_tests.CommonTest):
self.assertEqual(list(u), [1, 2])
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(list):
def __new__(cls, seq, newarg=None):
self = super().__new__(cls, seq)
@ -172,7 +172,7 @@ class ListTest(list_tests.CommonTest):
lst *= size
def test_repr_mutate(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Obj:
@staticmethod
def __repr__():
@ -276,14 +276,14 @@ class ListTest(list_tests.CommonTest):
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
# optimization causes failures in code that relies on distinct
# function addresses.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class L(list): pass
with self.assertRaises(TypeError):
(3,) + L([1,2])
def test_equal_operator_modifying_operand(self):
# test fix for seg fault reported in bpo-38588 part 2.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __eq__(self,other) :
list2.clear()
@ -308,7 +308,7 @@ class ListTest(list_tests.CommonTest):
self.assertFalse(list3 == list4)
def test_lt_operator_modifying_operand(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# See gh-120298
class evil:
def __lt__(self, other):
@ -320,7 +320,7 @@ class ListTest(list_tests.CommonTest):
a[0] < a
def test_list_index_modifing_operand(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
# See gh-120384
class evil:
def __init__(self, lst):
@ -346,7 +346,7 @@ class ListTest(list_tests.CommonTest):
# bpo-38610: The count(), index(), and remove() methods were not
# holding strong references to list elements while calling
# PyObject_RichCompareBool().
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __eq__(self, other):
lst.clear()
@ -356,7 +356,7 @@ class ListTest(list_tests.CommonTest):
with self.assertRaises(ValueError):
lst.index(lst)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class L(list):
def __eq__(self, other):
str(other)

View File

@ -87,7 +87,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- pass
- class TestBadCeil:
- __ceil__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestCeil:
+ def __ceil__(self):
+ return 42
@ -123,7 +123,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- pass
- class TestBadFloor:
- __floor__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestFloor:
+ def __floor__(self):
+ return 42
@ -143,7 +143,7 @@ index 5ee3055c871..5402cdc4a6c 100644
# Verify tuple subclasses are allowed
- class T(tuple):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(tuple):
+ pass
self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0)
@ -155,7 +155,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- class BadFloat:
- __float__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadFloat:
+ __float__ = BadDescr()
@ -176,7 +176,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- class IntegerLike(object):
- def __init__(self, value):
- self.value = value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class IntegerLike(object):
+ def __init__(self, value):
+ self.value = value
@ -233,7 +233,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- raise RuntimeError
- def __rmul__(self, other):
- raise RuntimeError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadMultiply:
+ def __mul__(self, other):
+ raise RuntimeError
@ -265,7 +265,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- __rmul__ = __mul__
- def __repr__(self):
- return f'Flt({int(self)})'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Int(int):
+ def __add__(self, other):
+ return Int(int(self) + int(other))
@ -302,7 +302,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- pass
- class TestBadTrunc:
- __trunc__ = BadDescr()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestTrunc:
+ def __trunc__(self):
+ return 23
@ -323,7 +323,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- class BadMultiply:
- def __rmul__(self, other):
- raise RuntimeError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class BadMultiply:
+ def __rmul__(self, other):
+ raise RuntimeError
@ -362,7 +362,7 @@ index 5ee3055c871..5402cdc4a6c 100644
- def __float__(self):
- self.converted = True
- 1/0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class F:
+ def __float__(self):
+ self.converted = True

View File

@ -475,7 +475,7 @@ class MathTests(__TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.ceil(NAN)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestCeil:
def __ceil__(self):
return 42
@ -633,7 +633,7 @@ class MathTests(__TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.floor(NAN)))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestFloor:
def __floor__(self):
return 42
@ -1056,7 +1056,7 @@ class MathTests(__TestCase):
)
# Verify tuple subclasses are allowed
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(tuple):
pass
self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0)
@ -1090,7 +1090,7 @@ class MathTests(__TestCase):
with self.assertRaises(TypeError):
dist([1], 2)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadFloat:
__float__ = BadDescr()
@ -1165,7 +1165,7 @@ class MathTests(__TestCase):
self.assertIs(type(s), int)
self.assertEqual(s, 0)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class IntegerLike(object):
def __init__(self, value):
self.value = value
@ -1399,7 +1399,7 @@ class MathTests(__TestCase):
self.assertEqual(sumprod([1], BasicIterClass(1)), 0)
# Error in multiplication
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadMultiply:
def __mul__(self, other):
raise RuntimeError
@ -1449,7 +1449,7 @@ class MathTests(__TestCase):
Decimal = decimal.Decimal
Fraction = fractions.Fraction
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Int(int):
def __add__(self, other):
return Int(int(self) + int(other))
@ -1988,7 +1988,7 @@ class MathTests(__TestCase):
self.assertEqual(math.trunc(-0.999999), -0)
self.assertEqual(math.trunc(-100.999), -100)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestTrunc:
def __trunc__(self):
return 23
@ -2231,7 +2231,7 @@ class MathTests(__TestCase):
self.assertEqual(prod([1., F(3, 2)]), 1.5)
# Error in multiplication
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class BadMultiply:
def __rmul__(self, other):
raise RuntimeError
@ -2540,7 +2540,7 @@ class MathTests(__TestCase):
def test_issue39871(self):
# A SystemError should not be raised if the first arg to atan2(),
# copysign(), or remainder() cannot be converted to a float.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class F:
def __float__(self):
self.converted = True

View File

@ -33,7 +33,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class C(object):
- def __eq__(self, other):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __eq__(self, other):
+ raise SyntaxError
@ -47,7 +47,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class C(object):
- def __ne__(self, other):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __ne__(self, other):
+ raise SyntaxError
@ -61,7 +61,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class M:
- def __matmul__(self, other):
- return other - 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class M:
+ def __matmul__(self, other):
+ return other - 1
@ -75,7 +75,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class C(object):
- def __bool__(self):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __bool__(self):
+ raise SyntaxError
@ -88,7 +88,7 @@ index d90f820052c..5d9fdfb70a4 100644
operator = self.module
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
a = A()
@ -101,7 +101,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class C(object):
- def __getattr__(self, name):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __getattr__(self, name):
+ raise SyntaxError
@ -115,7 +115,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class C(object):
- def __getitem__(self, name):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __getitem__(self, name):
+ raise SyntaxError
@ -129,7 +129,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class T(tuple):
- 'Tuple subclass'
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(tuple):
+ 'Tuple subclass'
+ pass
@ -147,7 +147,7 @@ index d90f820052c..5d9fdfb70a4 100644
- return f
- def baz(*args, **kwds):
- return kwds['name'], kwds['self']
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def foo(self, *args, **kwds):
+ return args[0] + args[1]
@ -177,7 +177,7 @@ index d90f820052c..5d9fdfb70a4 100644
- def __itruediv__ (self, other): return "itruediv"
- def __ixor__ (self, other): return "ixor"
- def __getitem__(self, other): return 5 # so that C is a sequence
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ def __iadd__ (self, other): return "iadd"
+ def __iand__ (self, other): return "iand"
@ -203,7 +203,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class X:
- def __index__(self):
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __index__(self):
+ return 1
@ -217,7 +217,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class C:
- def __bool__(self):
- raise SyntaxError
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __bool__(self):
+ raise SyntaxError
@ -231,7 +231,7 @@ index d90f820052c..5d9fdfb70a4 100644
- class X(object):
- def __init__(self, value):
- self.value = value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X(object):
+ def __init__(self, value):
+ self.value = value
@ -254,7 +254,7 @@ index d90f820052c..5d9fdfb70a4 100644
operator.length_hint(X(LookupError))
- class Y: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Y: pass
msg = "'str' object cannot be interpreted as an integer"
@ -279,7 +279,7 @@ index d90f820052c..5d9fdfb70a4 100644
attrgetter = self.module.attrgetter
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
a = A()
@ -296,7 +296,7 @@ index d90f820052c..5d9fdfb70a4 100644
- return f
- def baz(*args, **kwds):
- return kwds['name'], kwds['self']
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def foo(self, *args, **kwds):
+ return args[0] + args[1]

View File

@ -104,7 +104,7 @@ class OperatorTestCase:
def test_eq(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __eq__(self, other):
raise SyntaxError
@ -119,7 +119,7 @@ class OperatorTestCase:
def test_ne(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __ne__(self, other):
raise SyntaxError
@ -267,7 +267,7 @@ class OperatorTestCase:
operator = self.module
self.assertRaises(TypeError, operator.matmul)
self.assertRaises(TypeError, operator.matmul, 42, 42)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class M:
def __matmul__(self, other):
return other - 1
@ -338,7 +338,7 @@ class OperatorTestCase:
def test_truth(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __bool__(self):
raise SyntaxError
@ -373,7 +373,7 @@ class OperatorTestCase:
def test_attrgetter(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
a = A()
@ -396,7 +396,7 @@ class OperatorTestCase:
self.assertEqual(operator.attrgetter('x','z','y')(record), ('X', 'Z', 'Y'))
self.assertRaises(TypeError, operator.attrgetter, ('x', (), 'y'))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __getattr__(self, name):
raise SyntaxError
@ -437,7 +437,7 @@ class OperatorTestCase:
f = operator.itemgetter(10)
self.assertRaises(IndexError, f, a)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __getitem__(self, name):
raise SyntaxError
@ -471,7 +471,7 @@ class OperatorTestCase:
self.assertEqual(operator.itemgetter(slice(2, 4))(t), ('c', 'd'))
# interesting sequences
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(tuple):
'Tuple subclass'
pass
@ -483,7 +483,7 @@ class OperatorTestCase:
operator = self.module
self.assertRaises(TypeError, operator.methodcaller)
self.assertRaises(TypeError, operator.methodcaller, 12)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def foo(self, *args, **kwds):
return args[0] + args[1]
@ -509,7 +509,7 @@ class OperatorTestCase:
def test_inplace(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
def __iadd__ (self, other): return "iadd"
def __iand__ (self, other): return "iand"
@ -550,7 +550,7 @@ class OperatorTestCase:
def test_index(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __index__(self):
return 1
@ -570,7 +570,7 @@ class OperatorTestCase:
def test_not_(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __bool__(self):
raise SyntaxError
@ -583,7 +583,7 @@ class OperatorTestCase:
def test_length_hint(self):
operator = self.module
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X(object):
def __init__(self, value):
self.value = value
@ -607,7 +607,7 @@ class OperatorTestCase:
with self.assertRaises(LookupError):
operator.length_hint(X(LookupError))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Y: pass
msg = "'str' object cannot be interpreted as an integer"
@ -679,7 +679,7 @@ class OperatorPickleTestCase:
def test_attrgetter(self):
attrgetter = self.module.attrgetter
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
a = A()
@ -723,7 +723,7 @@ class OperatorPickleTestCase:
def test_methodcaller(self):
methodcaller = self.module.methodcaller
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
def foo(self, *args, **kwds):
return args[0] + args[1]

View File

@ -74,7 +74,7 @@ index a9b6a84996e..efc4288d1a4 100644
- def items(self):
- calls.append('items')
- return ()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Spam:
+ def keys(self):
+ calls.append('keys')
@ -92,7 +92,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class ODNI(OrderedDict):
- def __init__(*args, **kwargs):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ODNI(OrderedDict):
+ def __init__(*args, **kwargs):
+ pass
@ -106,7 +106,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class Missing(OrderedDict):
- def __missing__(self, key):
- return 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Missing(OrderedDict):
+ def __missing__(self, key):
+ return 0
@ -120,7 +120,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class Missing(OrderedDict):
- def __missing__(self, key):
- return 0
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Missing(OrderedDict):
+ def __missing__(self, key):
+ return 0
@ -134,7 +134,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class MyOD(OrderedDict):
- def update(self, *args, **kwds):
- raise Exception()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyOD(OrderedDict):
+ def update(self, *args, **kwds):
+ raise Exception()
@ -148,7 +148,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class MyOD(OrderedDict):
- def __del__(self):
- deleted.append(self.i)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyOD(OrderedDict):
+ def __del__(self):
+ deleted.append(self.i)
@ -172,7 +172,7 @@ index a9b6a84996e..efc4288d1a4 100644
- return False
- def __repr__(self):
- return self.value
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key:
+ def __init__(self, hash):
+ self._hash = hash
@ -196,7 +196,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class Key:
- def __hash__(self):
- return randrange(100000)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key:
+ def __hash__(self):
+ return randrange(100000)
@ -210,7 +210,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class Key:
- def __hash__(self):
- return 1
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key:
+ def __hash__(self):
+ return 1
@ -231,7 +231,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- del dict1[TODEL]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ del dict1[TODEL]
@ -262,7 +262,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- dict1.clear()
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ dict1.clear()
@ -276,7 +276,7 @@ index a9b6a84996e..efc4288d1a4 100644
- class Key(_TriggerSideEffectOnEqual):
- def side_effect(self):
- del dict1[TODEL]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ del dict1[TODEL]
@ -291,7 +291,7 @@ index a9b6a84996e..efc4288d1a4 100644
- def side_effect(self):
- dict1.clear()
- dict1['a'] = dict1['b'] = 'c'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ dict1.clear()
@ -307,7 +307,7 @@ index a9b6a84996e..efc4288d1a4 100644
- def side_effect(self):
- del dict1[TODEL]
- dict1['a'] = 'c'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ def side_effect(self):
+ del dict1[TODEL]
@ -323,7 +323,7 @@ index a9b6a84996e..efc4288d1a4 100644
- trigger = 0
- def side_effect(self):
- del dict1[TODEL]
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Key(_TriggerSideEffectOnEqual):
+ trigger = 0
+ def side_effect(self):

View File

@ -170,7 +170,7 @@ class OrderedDictTests:
def test_init_calls(self):
calls = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Spam:
def keys(self):
calls.append('keys')
@ -187,7 +187,7 @@ class OrderedDictTests:
# a consistent internal state is created in __new__
# rather than __init__.
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ODNI(OrderedDict):
def __init__(*args, **kwargs):
pass
@ -326,7 +326,7 @@ class OrderedDictTests:
self.assertEqual(od.pop(k, 12345), 12345)
# make sure pop still works when __missing__ is defined
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Missing(OrderedDict):
def __missing__(self, key):
return 0
@ -476,7 +476,7 @@ class OrderedDictTests:
self.assertEqual(od.setdefault('g', default=9), 9)
# make sure setdefault still works when __missing__ is defined
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Missing(OrderedDict):
def __missing__(self, key):
return 0
@ -545,7 +545,7 @@ class OrderedDictTests:
def test_override_update(self):
OrderedDict = self.OrderedDict
# Verify that subclasses can override update() without breaking __init__()
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyOD(OrderedDict):
def update(self, *args, **kwds):
raise Exception()
@ -569,7 +569,7 @@ class OrderedDictTests:
# should not crash Python.
OrderedDict = self.OrderedDict
deleted = []
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyOD(OrderedDict):
def __del__(self):
deleted.append(self.i)
@ -584,7 +584,7 @@ class OrderedDictTests:
def test_delitem_hash_collision(self):
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key:
def __init__(self, hash):
self._hash = hash
@ -624,7 +624,7 @@ class OrderedDictTests:
def test_issue24347(self):
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key:
def __hash__(self):
return randrange(100000)
@ -647,7 +647,7 @@ class OrderedDictTests:
def test_issue24348(self):
OrderedDict = self.OrderedDict
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key:
def __hash__(self):
return 1
@ -832,7 +832,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, __TestCase):
OrderedDict = py_coll.OrderedDict
def test_issue119004_attribute_error(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
del dict1[TODEL]
@ -875,7 +875,7 @@ class CPythonOrderedDictSideEffects:
self.assertRaisesRegex(RuntimeError, msg, operator.eq, dict1, dict2)
def test_issue119004_change_size_by_clear(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
dict1.clear()
@ -888,7 +888,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_size_by_delete_key(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
del dict1[TODEL]
@ -902,7 +902,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_linked_list_by_clear(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
dict1.clear()
@ -916,7 +916,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_linked_list_by_delete_key(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
def side_effect(self):
del dict1[TODEL]
@ -931,7 +931,7 @@ class CPythonOrderedDictSideEffects:
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
def test_issue119004_change_size_by_delete_key_in_dict_eq(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Key(_TriggerSideEffectOnEqual):
trigger = 0
def side_effect(self):

View File

@ -87,7 +87,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- return self.value
- def __deepcopy__(self, memo=None):
- return Tracer(self.value + 1)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Tracer:
+ def __init__(self, value):
+ self.value = value
@ -104,7 +104,7 @@ index d9102eb98a5..c8ee5ca451f 100644
# Create a nest of cycles to exercise overall ref count check
- class A:
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
s = set(A() for i in range(1000))
@ -117,7 +117,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- class H(self.thetype):
- def __hash__(self):
- return int(id(self) & 0x7fffffff)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class H(self.thetype):
+ def __hash__(self):
+ return int(id(self) & 0x7fffffff)
@ -130,7 +130,7 @@ index d9102eb98a5..c8ee5ca451f 100644
# Bug #3680: tp_traverse was not implemented for set iterator object
- class C(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ pass
obj = C()
@ -162,7 +162,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- def __le__(self, some_set):
- self.le_called = True
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class TestRichSetCompare:
+ def __gt__(self, some_set):
+ self.gt_called = True
@ -185,7 +185,7 @@ index d9102eb98a5..c8ee5ca451f 100644
def test_keywords_in_subclass(self):
- class subclass(set):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(set):
+ pass
u = subclass([1, 2])
@ -198,7 +198,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- def __init__(self, arg, newarg=None):
- super().__init__(arg)
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(set):
+ def __init__(self, arg, newarg=None):
+ super().__init__(arg)
@ -213,7 +213,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(set):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -237,7 +237,7 @@ index d9102eb98a5..c8ee5ca451f 100644
def test_keywords_in_subclass(self):
- class subclass(frozenset):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(frozenset):
+ pass
u = subclass([1, 2])
@ -249,7 +249,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- class subclass_with_init(frozenset):
- def __init__(self, arg, newarg=None):
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(frozenset):
+ def __init__(self, arg, newarg=None):
+ self.newarg = newarg
@ -263,7 +263,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(frozenset):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -700,7 +700,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- def __eq__(self, o):
- other.clear()
- return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __hash__(self):
+ return hash(0)
@ -733,7 +733,7 @@ index d9102eb98a5..c8ee5ca451f 100644
- return bool(randrange(2))
- def __hash__(self):
- return randrange(2)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Bad:
+ def __eq__(self, other):
+ if not enabled:

View File

@ -315,7 +315,7 @@ class _TestJointOps:
self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
def test_deepcopy(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Tracer:
def __init__(self, value):
self.value = value
@ -334,7 +334,7 @@ class _TestJointOps:
def test_gc(self):
# Create a nest of cycles to exercise overall ref count check
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class A:
pass
s = set(A() for i in range(1000))
@ -345,7 +345,7 @@ class _TestJointOps:
def test_subclass_with_custom_hash(self):
# Bug #1257731
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class H(self.thetype):
def __hash__(self):
return int(id(self) & 0x7fffffff)
@ -399,7 +399,7 @@ class _TestJointOps:
def test_container_iterator(self):
# Bug #3680: tp_traverse was not implemented for set iterator object
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C(object):
pass
obj = C()
@ -658,7 +658,7 @@ class TestSet(_TestJointOps, __TestCase):
self.assertRaises(ReferenceError, str, p)
def test_rich_compare(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class TestRichSetCompare:
def __gt__(self, some_set):
self.gt_called = True
@ -703,7 +703,7 @@ class TestSetSubclass(TestSet):
basetype = set
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(set):
pass
u = subclass([1, 2])
@ -712,7 +712,7 @@ class TestSetSubclass(TestSet):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(set):
def __init__(self, arg, newarg=None):
super().__init__(arg)
@ -722,7 +722,7 @@ class TestSetSubclass(TestSet):
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(set):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -818,7 +818,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
basetype = frozenset
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(frozenset):
pass
u = subclass([1, 2])
@ -827,7 +827,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(frozenset):
def __init__(self, arg, newarg=None):
self.newarg = newarg
@ -836,7 +836,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(frozenset):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -1907,7 +1907,7 @@ class TestWeirdBugs(__TestCase):
list(si)
def test_merge_and_mutate(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class X:
def __hash__(self):
return hash(0)
@ -1928,7 +1928,7 @@ class _TestOperationsMutating:
constructor2 = None
def make_sets_of_bad_objects(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Bad:
def __eq__(self, other):
if not enabled:

View File

@ -75,7 +75,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- class Complains(object):
- maybe_complain = True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class Complains(object):
+ maybe_complain = True
@ -142,7 +142,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- else:
- L.append(3)
- return random.random() < 0.5
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __lt__(self, other):
+ if L and random.random() < 0.75:
@ -174,7 +174,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- data[:] = range(20)
- def __lt__(self, other):
- return id(self) < id(other)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SortKiller(object):
+ def __init__(self, x):
+ pass
@ -195,7 +195,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- def __del__(self):
- del data[:]
- data[:] = list(range(20))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class SortKiller(object):
+ def __init__(self, x):
+ if x > 2:
@ -223,7 +223,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- def __lt__(self, other):
- elem.__class__ = WackyList2
- return int.__lt__(self, other)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class WackyComparator(int):
+ def __lt__(self, other):
+ elem.__class__ = WackyList2
@ -250,7 +250,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
- class PointlessComparator:
- def __lt__(self, other):
- return NotImplemented
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class PointlessComparator:
+ def __lt__(self, other):
+ return NotImplemented

View File

@ -102,7 +102,7 @@ class TestBase(__TestCase):
sizes.extend(range(n-1, n+2))
sizes.extend([10, 100, 1000])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class Complains(object):
maybe_complain = True
@ -213,7 +213,7 @@ class TestBugs(__TestCase):
# If this fails, the most likely outcome is a core dump.
# Mutations during a list sort should raise a ValueError.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __lt__(self, other):
if L and random.random() < 0.75:
@ -284,7 +284,7 @@ class TestDecorateSortUndecorate(__TestCase):
def test_key_with_mutating_del(self):
data = list(range(10))
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SortKiller(object):
def __init__(self, x):
pass
@ -298,7 +298,7 @@ class TestDecorateSortUndecorate(__TestCase):
def test_key_with_mutating_del_and_exception(self):
data = list(range(10))
## dup = data[:]
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class SortKiller(object):
def __init__(self, x):
if x > 2:
@ -389,7 +389,7 @@ class TestOptimizedCompares(__TestCase):
# This test is by ppperry. It ensures that unsafe_object_compare is
# verifying ms->key_richcompare == tp->richcompare before comparing.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class WackyComparator(int):
def __lt__(self, other):
elem.__class__ = WackyList2
@ -414,7 +414,7 @@ class TestOptimizedCompares(__TestCase):
# The following test is also by ppperry. It ensures that
# unsafe_object_compare handles Py_NotImplemented appropriately.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class PointlessComparator:
def __lt__(self, other):
return NotImplemented

View File

@ -68,7 +68,7 @@ index 9ce80c5e8ea..1080e85e31a 100644
def test_keywords_in_subclass(self):
- class subclass(tuple):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(tuple):
+ pass
u = subclass([1, 2])
@ -80,7 +80,7 @@ index 9ce80c5e8ea..1080e85e31a 100644
- class subclass_with_init(tuple):
- def __init__(self, arg, newarg=None):
- self.newarg = newarg
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(tuple):
+ def __init__(self, arg, newarg=None):
+ self.newarg = newarg
@ -94,7 +94,7 @@ index 9ce80c5e8ea..1080e85e31a 100644
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(tuple):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
@ -109,7 +109,7 @@ index 9ce80c5e8ea..1080e85e31a 100644
# Tuple subtypes must always be tracked
- class MyTuple(tuple):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class MyTuple(tuple):
+ pass
self.check_track_dynamic(MyTuple, True)
@ -120,7 +120,7 @@ index 9ce80c5e8ea..1080e85e31a 100644
# optimization causes failures in code that relies on distinct
# function addresses.
- class T(tuple): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(tuple): pass
with self.assertRaises(TypeError):
[3,] + T((1,2))

View File

@ -97,7 +97,7 @@ class TupleTest(seq_tests.CommonTest):
tuple(sequence=())
def test_keywords_in_subclass(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass(tuple):
pass
u = subclass([1, 2])
@ -106,7 +106,7 @@ class TupleTest(seq_tests.CommonTest):
with self.assertRaises(TypeError):
subclass(sequence=())
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_init(tuple):
def __init__(self, arg, newarg=None):
self.newarg = newarg
@ -115,7 +115,7 @@ class TupleTest(seq_tests.CommonTest):
self.assertEqual(list(u), [1, 2])
self.assertEqual(u.newarg, 3)
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class subclass_with_new(tuple):
def __new__(cls, arg, newarg=None):
self = super().__new__(cls, arg)
@ -408,7 +408,7 @@ class TupleTest(seq_tests.CommonTest):
@support.cpython_only
def test_track_subtypes(self):
# Tuple subtypes must always be tracked
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class MyTuple(tuple):
pass
self.check_track_dynamic(MyTuple, True)
@ -462,7 +462,7 @@ class TupleTest(seq_tests.CommonTest):
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
# optimization causes failures in code that relies on distinct
# function addresses.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(tuple): pass
with self.assertRaises(TypeError):
[3,] + T((1,2))

View File

@ -72,7 +72,7 @@ index 312702c8e39..d3d8dbf394a 100644
- class T(self.type2test):
- def __getitem__(self, key):
- return str(key) + '!!!'
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class T(self.type2test):
+ def __getitem__(self, key):
+ return str(key) + '!!!'

View File

@ -110,7 +110,7 @@ class UserListTest(list_tests.CommonTest):
def test_getitemoverwriteiter(self):
# Verify that __getitem__ overrides *are* recognized by __iter__
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class T(self.type2test):
def __getitem__(self, key):
return str(key) + '!!!'

View File

@ -41,7 +41,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- class LacksEnter(object):
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnter(object):
+ def __exit__(self, type, value, traceback):
+ pass
@ -54,7 +54,7 @@ index 8e9ed8500c7..66c18ad886a 100644
def testEnterAttributeError2(self):
- class LacksEnterAndExit(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksEnterAndExit(object):
+ pass
@ -67,7 +67,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- class LacksExit(object):
- def __enter__(self):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class LacksExit(object):
+ def __enter__(self):
+ pass
@ -83,7 +83,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- raise RuntimeError("Enter threw")
- def __exit__(self, *args):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class EnterThrows(object):
+ def __enter__(self):
+ raise RuntimeError("Enter threw")
@ -101,7 +101,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- return
- def __exit__(self, *args):
- raise RuntimeError(42)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class ExitThrows(object):
+ def __enter__(self):
+ return
@ -154,7 +154,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class cm(object):
+ def __enter__(self):
+ pass
@ -172,7 +172,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class cm (object):
+ def __enter__(self):
+ pass
@ -195,7 +195,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- return 3
- def __exit__(self, a, b, c):
- return self.exit_result
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class cm(object):
+ def __init__(self, bool_conversion):
+ class Bool:
@ -232,7 +232,7 @@ index 8e9ed8500c7..66c18ad886a 100644
keys.sort()
self.assertEqual(keys, [1, 2])
- class C: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
@ -242,7 +242,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- class C:
- def __enter__(self): return 1, 2, 3
- def __exit__(self, t, v, tb): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __enter__(self): return 1, 2, 3
+ def __exit__(self, t, v, tb): pass
@ -254,7 +254,7 @@ index 8e9ed8500c7..66c18ad886a 100644
with C() as (targets[1], targets[2], targets[3]):
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
- class B: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class B: pass
blah = B()
with C() as (blah.one, blah.two, blah.three):
@ -270,7 +270,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- class AfricanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class AfricanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return True
@ -284,7 +284,7 @@ index 8e9ed8500c7..66c18ad886a 100644
- class EuropeanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ with torch._dynamo.error_on_graph_break(False):
+ class EuropeanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return False

View File

@ -131,7 +131,7 @@ class FailureTestCase(__TestCase):
self.assertRaises(NameError, fooNotDeclared)
def testEnterAttributeError1(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksEnter(object):
def __exit__(self, type, value, traceback):
pass
@ -142,7 +142,7 @@ class FailureTestCase(__TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)
def testEnterAttributeError2(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksEnterAndExit(object):
pass
@ -152,7 +152,7 @@ class FailureTestCase(__TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)
def testExitAttributeError(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class LacksExit(object):
def __enter__(self):
pass
@ -185,7 +185,7 @@ class FailureTestCase(__TestCase):
' pass')
def testEnterThrows(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EnterThrows(object):
def __enter__(self):
raise RuntimeError("Enter threw")
@ -204,7 +204,7 @@ class FailureTestCase(__TestCase):
self.assertEqual(self.foo, None)
def testExitThrows(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class ExitThrows(object):
def __enter__(self):
return
@ -492,7 +492,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
def testRaisedStopIteration2(self):
# From bug 1462485
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class cm(object):
def __enter__(self):
pass
@ -534,7 +534,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
def testRaisedGeneratorExit2(self):
# From bug 1462485
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class cm (object):
def __enter__(self):
pass
@ -551,7 +551,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
# issue4589: __exit__ return code may raise an exception
# when looking at its truth value.
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class cm(object):
def __init__(self, bool_conversion):
class Bool:
@ -650,14 +650,14 @@ class AssignmentTargetTestCase(__TestCase):
keys = list(targets.keys())
keys.sort()
self.assertEqual(keys, [1, 2])
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
self.assertEqual(hasattr(blah, "foo"), True)
def testMultipleComplexTargets(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class C:
def __enter__(self): return 1, 2, 3
def __exit__(self, t, v, tb): pass
@ -668,7 +668,7 @@ class AssignmentTargetTestCase(__TestCase):
self.assertEqual(targets, {1: [3, 2, 1]})
with C() as (targets[1], targets[2], targets[3]):
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class B: pass
blah = B()
with C() as (blah.one, blah.two, blah.three):
@ -686,7 +686,7 @@ class AssignmentTargetTestCase(__TestCase):
class ExitSwallowsExceptionTestCase(__TestCase):
def testExitTrueSwallowsException(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class AfricanSwallow:
def __enter__(self): pass
def __exit__(self, t, v, tb): return True
@ -697,7 +697,7 @@ class ExitSwallowsExceptionTestCase(__TestCase):
self.fail("ZeroDivisionError should have been swallowed")
def testExitFalseDoesntSwallowException(self):
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.error_on_graph_break(False):
class EuropeanSwallow:
def __enter__(self): pass
def __exit__(self, t, v, tb): return False

View File

@ -1721,13 +1721,13 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
):
f4(torch.randn(3))
def test_set_fullgraph(self):
def test_error_on_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
def f1(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
return x + 2
@ -1738,7 +1738,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
@torch.compile(backend=cnts)
def f2(x):
x = x + 1
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 2
@ -1748,7 +1748,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
@torch.compile(backend=cnts, fullgraph=True)
def f3(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
x = x + 2
torch._dynamo.graph_break()
@ -1766,7 +1766,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
@torch.compile(backend=cnts, fullgraph=True)
def f4(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.skip_frame()
return inner_f4(x)
@ -1774,11 +1774,11 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(f4(inp), inp + 7)
self.assertEqual(cnts.frame_count, 2)
def test_set_fullgraph_nested(self):
# set_fullgraph in a nested frame
def test_error_on_graph_break_nested(self):
# error_on_graph_break in a nested frame
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
@torch._dynamo.error_on_graph_break(False)
def inner_f5(x):
x = x + 2
torch._dynamo.graph_break()
@ -1795,7 +1795,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner_f6(x):
x = x + 2
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
return x + 4
@ -1810,7 +1810,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner_f7(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 4
@ -1822,18 +1822,18 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
f7(inp)
def test_set_fullgraph_nested_with_skip(self):
# set_fullgraph in a nested frame with a skipped frame in between
def test_error_on_graph_break_nested_with_skip(self):
# error_on_graph_break in a nested frame with a skipped frame in between
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
@torch._dynamo.error_on_graph_break(False)
def inner2_f8(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
def inner1_f8(x):
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.skip_frame()
return inner2_f8(x)
@ -1848,7 +1848,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner2_f9(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 4
@ -1864,10 +1864,10 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
f9(inp)
# test export with set_fullgraph(False) still errors
# test export with error_on_graph_break(False) still errors
def test_set_fullgraph_export(self):
@torch._dynamo.set_fullgraph(False)
def test_error_on_graph_break_export(self):
@torch._dynamo.error_on_graph_break(False)
def inner(x):
x = x + 2
torch._dynamo.graph_break()
@ -1880,7 +1880,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
torch._dynamo.export(f)(torch.ones(3))
def test_set_fullgraph_nested_deep(self):
def test_error_on_graph_break_nested_deep(self):
cnts = torch._dynamo.testing.CompileCounter()
def inner1_f1(x):
@ -1892,7 +1892,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
return inner1_f1(x)
def inner3_f1(x):
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
return inner2_f1(x)
def inner4_f1(x):
@ -1916,7 +1916,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
return inner1_f2(x)
def inner3_f2(x):
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
return inner2_f2(x)
def inner4_f2(x):
@ -1930,20 +1930,20 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
f2(inp)
def test_set_fullgraph_error(self):
def test_error_on_graph_break_error(self):
@torch.compile(backend="eager")
def f1():
with torch._dynamo.set_fullgraph(foo="bar"):
with torch._dynamo.error_on_graph_break(foo="bar"):
pass
@torch.compile(backend="eager")
def f2():
with torch._dynamo.set_fullgraph():
with torch._dynamo.error_on_graph_break():
pass
@torch.compile(backend="eager")
def f3():
with torch._dynamo.set_fullgraph("foo"):
with torch._dynamo.error_on_graph_break("foo"):
pass
with self.assertRaises(Exception):

View File

@ -21,6 +21,7 @@ from .decorators import (
disable,
disallow_in_graph,
dont_skip_tracing,
error_on_graph_break,
forbid_in_graph,
graph_break,
mark_dynamic,
@ -30,7 +31,6 @@ from .decorators import (
nonstrict_trace,
patch_dynamo_config,
run,
set_fullgraph,
set_stance,
skip_frame,
substitute_in_graph,
@ -90,7 +90,7 @@ __all__ = [
"replay",
"reset",
"run",
"set_fullgraph",
"error_on_graph_break",
"set_stance",
"skip_frame",
"substitute_in_graph",

View File

@ -1747,7 +1747,7 @@ def replay(filename: str) -> None:
record = ExecutionRecord.load(in_file)
record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
with decorators.set_fullgraph(fullgraph=False):
with decorators.error_on_graph_break(False):
try:
_compile(
record.code,

View File

@ -918,15 +918,15 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any:
return ctx
class SetFullgraphDecoratorContextManager:
def __init__(self, fullgraph: bool) -> None:
self.fullgraph = fullgraph
class ErrorOnGraphBreakDecoratorContextManager:
def __init__(self, error_on_graph_break: bool) -> None:
self.error_on_graph_break = error_on_graph_break
__call__ = wrap_dunder_call_ctx_manager
def __enter__(self) -> None:
self.prev_fullgraph = _get_error_on_graph_break()
_set_error_on_graph_break(self.fullgraph)
self.prev_error_on_graph_break = _get_error_on_graph_break()
_set_error_on_graph_break(self.error_on_graph_break)
def __exit__(
self,
@ -934,14 +934,16 @@ class SetFullgraphDecoratorContextManager:
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
_set_error_on_graph_break(self.prev_fullgraph)
_set_error_on_graph_break(self.prev_error_on_graph_break)
def set_fullgraph(fullgraph: bool) -> SetFullgraphDecoratorContextManager:
def error_on_graph_break(
error_on_graph_break: bool,
) -> ErrorOnGraphBreakDecoratorContextManager:
"""
Context manager/decorator to toggle fullgraph setting.
Context manager/decorator to toggle error_on_graph_break (i.e. torch.compile's fullgraph) setting.
More precisely, when encountering a graph break, we will decide to resume (fullgraph=False)
or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break.
"""
return SetFullgraphDecoratorContextManager(fullgraph)
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)

View File

@ -858,7 +858,9 @@ class _TorchDynamoContext:
# hooks to properly handle inlining
compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined]
external_utils.wrap_inline_with_set_fullgraph(fn, self.error_on_graph_break)
external_utils.wrap_inline_with_error_on_graph_break(
fn, self.error_on_graph_break
)
)
# Save the function pointer to find the original callable while nesting

View File

@ -229,23 +229,23 @@ def call_accumulate_grad(
variable.grad = updated_grad[0]
def wrap_inline_with_set_fullgraph(
fn: Callable[_P, _R], fullgraph: bool
def wrap_inline_with_error_on_graph_break(
fn: Callable[_P, _R], error_on_graph_break: bool
) -> Callable[_P, _R]:
# NB: need multiple definitions in order to prevent `fullgraph` from
# being a freevar of wrapper
if fullgraph:
if error_on_graph_break:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.set_fullgraph(True):
with torch._dynamo.error_on_graph_break(True):
return fn(*args, **kwargs)
else:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.set_fullgraph(False):
with torch._dynamo.error_on_graph_break(False):
return fn(*args, **kwargs)
return wrapper

View File

@ -348,7 +348,7 @@ manual_torch_name_rule_map: dict[
"torch._dynamo.mark_static": UserFunctionVariable,
"torch._dynamo.nonstrict_trace": UserFunctionVariable,
"torch._dynamo.patch_dynamo_config": UserFunctionVariable,
"torch._dynamo.set_fullgraph": UserFunctionVariable,
"torch._dynamo.error_on_graph_break": UserFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,

View File

@ -27,6 +27,7 @@ from .ctx_manager import (
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
FSDPParamGroupUseTrainingStateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
@ -34,7 +35,6 @@ from .ctx_manager import (
InferenceModeVariable,
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFullgraphVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
@ -200,7 +200,7 @@ __all__ = [
"RemovableHandleVariable",
"RepeatIteratorVariable",
"SDPAParamsVariable",
"SetFullgraphVariable",
"ErrorOnGraphBreakVariable",
"SkipFunctionVariable",
"SliceVariable",
"StringFormatVariable",

View File

@ -168,10 +168,10 @@ from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
AutocastModeVariable,
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
EventVariable,
NullContextVariable,
PreserveVersionContextVariable,
SetFullgraphVariable,
StreamContextVariable,
StreamVariable,
)
@ -630,7 +630,7 @@ class VariableBuilder:
from ..decorators import (
DynamoConfigPatchProxy,
SetFullgraphDecoratorContextManager,
ErrorOnGraphBreakDecoratorContextManager,
)
if has_triton():
@ -988,8 +988,8 @@ class VariableBuilder:
)
elif isinstance(value, DynamoConfigPatchProxy):
return DynamoConfigPatchVariable(value.changes)
elif isinstance(value, SetFullgraphDecoratorContextManager):
return SetFullgraphVariable(value.fullgraph)
elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager):
return ErrorOnGraphBreakVariable(value.error_on_graph_break)
elif callable(value) and trace_rules.lookup_callable(value) is not None:
if trace_rules.is_callable_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True

View File

@ -1429,12 +1429,12 @@ class DynamoConfigPatchVariable(ContextWrappingVariable):
return "patch_dynamo_config"
class SetFullgraphVariable(ContextWrappingVariable):
"""represents torch._dynamo.set_fullgraph"""
class ErrorOnGraphBreakVariable(ContextWrappingVariable):
"""represents torch._dynamo.error_on_graph_break"""
def __init__(self, fullgraph, **kwargs) -> None:
def __init__(self, error_on_graph_break, **kwargs) -> None:
super().__init__(
target_values=(fullgraph,),
target_values=(error_on_graph_break,),
initial_values=(_get_error_on_graph_break(),),
**kwargs,
)
@ -1447,7 +1447,7 @@ class SetFullgraphVariable(ContextWrappingVariable):
return "torch._dynamo"
def fn_name(self):
return "set_fullgraph"
return "error_on_graph_break"
class WithExitFunctionVariable(VariableTracker):

View File

@ -521,15 +521,17 @@ class UserFunctionVariable(BaseUserFunctionVariable):
"Please fix your call to patch_dynamo_config by using simpler inputs. "
f"args: {args}, kwargs: {kwargs}"
) from e
elif self.fn is torch._dynamo.set_fullgraph:
elif self.fn is torch._dynamo.error_on_graph_break:
try:
bound = inspect.signature(self.fn).bind(*args, **kwargs)
fullgraph = bound.arguments["fullgraph"].as_python_constant()
assert isinstance(fullgraph, bool)
return variables.SetFullgraphVariable(fullgraph)
error_on_graph_break = bound.arguments[
"error_on_graph_break"
].as_python_constant()
assert isinstance(error_on_graph_break, bool)
return variables.ErrorOnGraphBreakVariable(error_on_graph_break)
except Exception as e:
raise RuntimeError(
"Improper set_fullgraph() call. Please fix your call to set_fullgraph(). "
"Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). "
f"args: {args}, kwargs: {kwargs}"
) from e
# Handle a `nonstrict_trace(fn)` call