mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a followup PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph` so that `fullgraph` really returns 1 graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still private API (there are no internal callsites yet, and there are no significant OSS callsites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph` Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739 Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
parent
1281470155
commit
8678d831c4
|
|
@ -62,16 +62,16 @@ index dbc5ef4f9f2..af717703053 100644
|
|||
@@ -5,7 +58,7 @@ Tests common to list and UserList.UserList
|
||||
import sys
|
||||
from functools import cmp_to_key
|
||||
|
||||
|
||||
-from test import seq_tests
|
||||
+import seq_tests
|
||||
from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -119,10 +172,6 @@ class CommonTest(seq_tests.CommonTest):
|
||||
a[-1] = 9
|
||||
self.assertEqual(a, self.type2test([5,6,7,8,9]))
|
||||
|
||||
|
||||
- msg = "list indices must be integers or slices"
|
||||
- with self.assertRaisesRegex(TypeError, msg):
|
||||
- a['a'] = "python"
|
||||
|
|
@ -81,7 +81,7 @@ index dbc5ef4f9f2..af717703053 100644
|
|||
del a[1]
|
||||
@@ -270,13 +319,14 @@ class CommonTest(seq_tests.CommonTest):
|
||||
self.assertRaises(TypeError, a.extend)
|
||||
|
||||
|
||||
# overflow test. issue1621
|
||||
- class CustomIter:
|
||||
- def __iter__(self):
|
||||
|
|
@ -90,7 +90,7 @@ index dbc5ef4f9f2..af717703053 100644
|
|||
- raise StopIteration
|
||||
- def __length_hint__(self):
|
||||
- return sys.maxsize
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class CustomIter:
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
|
|
@ -104,13 +104,13 @@ index dbc5ef4f9f2..af717703053 100644
|
|||
@@ -337,21 +387,23 @@ class CommonTest(seq_tests.CommonTest):
|
||||
a = self.type2test([NEVER_EQ])
|
||||
self.assertRaises(ValueError, a.remove, ALWAYS_EQ)
|
||||
|
||||
|
||||
- class BadExc(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadExc(Exception):
|
||||
+ pass
|
||||
|
||||
|
||||
- class BadCmp:
|
||||
- def __eq__(self, other):
|
||||
- if other == 2:
|
||||
|
|
@ -121,24 +121,24 @@ index dbc5ef4f9f2..af717703053 100644
|
|||
+ if other == 2:
|
||||
+ raise BadExc()
|
||||
+ return False
|
||||
|
||||
|
||||
a = self.type2test([0, 1, 2, 3])
|
||||
self.assertRaises(BadExc, a.remove, BadCmp())
|
||||
|
||||
|
||||
- class BadCmp2:
|
||||
- def __eq__(self, other):
|
||||
- raise BadExc()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadCmp2:
|
||||
+ def __eq__(self, other):
|
||||
+ raise BadExc()
|
||||
|
||||
|
||||
d = self.type2test('abcdefghcij')
|
||||
d.remove('c')
|
||||
@@ -376,13 +428,14 @@ class CommonTest(seq_tests.CommonTest):
|
||||
self.assertRaises(ValueError, a.index, 2, 0, 4)
|
||||
self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2]))
|
||||
|
||||
|
||||
- # Test modifying the list during index's iteration
|
||||
- class EvilCmp:
|
||||
- def __init__(self, victim):
|
||||
|
|
@ -146,7 +146,7 @@ index dbc5ef4f9f2..af717703053 100644
|
|||
- def __eq__(self, other):
|
||||
- del self.victim[:]
|
||||
- return False
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Test modifying the list during index's iteration
|
||||
+ class EvilCmp:
|
||||
+ def __init__(self, victim):
|
||||
|
|
|
|||
|
|
@ -319,7 +319,7 @@ class CommonTest(seq_tests.CommonTest):
|
|||
self.assertRaises(TypeError, a.extend)
|
||||
|
||||
# overflow test. issue1621
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomIter:
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
@ -387,7 +387,7 @@ class CommonTest(seq_tests.CommonTest):
|
|||
a = self.type2test([NEVER_EQ])
|
||||
self.assertRaises(ValueError, a.remove, ALWAYS_EQ)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadExc(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -400,7 +400,7 @@ class CommonTest(seq_tests.CommonTest):
|
|||
a = self.type2test([0, 1, 2, 3])
|
||||
self.assertRaises(BadExc, a.remove, BadCmp())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadCmp2:
|
||||
def __eq__(self, other):
|
||||
raise BadExc()
|
||||
|
|
@ -428,7 +428,7 @@ class CommonTest(seq_tests.CommonTest):
|
|||
self.assertRaises(ValueError, a.index, 2, 0, 4)
|
||||
self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2]))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Test modifying the list during index's iteration
|
||||
class EvilCmp:
|
||||
def __init__(self, victim):
|
||||
|
|
|
|||
|
|
@ -61,16 +61,16 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
import unittest
|
||||
import collections
|
||||
from test.support import get_c_recursion_limit
|
||||
|
||||
|
||||
|
||||
|
||||
-class BasicTestMappingProtocol(unittest.TestCase):
|
||||
+class BasicTestMappingProtocol(__TestCase):
|
||||
# This base class can be used to check that an object conforms to the
|
||||
# mapping protocol
|
||||
|
||||
|
||||
@@ -196,70 +250,76 @@ class BasicTestMappingProtocol(unittest.TestCase):
|
||||
self.assertRaises((TypeError, AttributeError), d.update, 42)
|
||||
|
||||
|
||||
outerself = self
|
||||
- class SimpleUserDict:
|
||||
- def __init__(self):
|
||||
|
|
@ -79,7 +79,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
- return self.d.keys()
|
||||
- def __getitem__(self, i):
|
||||
- return self.d[i]
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class SimpleUserDict:
|
||||
+ def __init__(self):
|
||||
+ self.d = outerself.reference
|
||||
|
|
@ -92,23 +92,23 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
i1 = sorted(d.items())
|
||||
i2 = sorted(self.reference.items())
|
||||
self.assertEqual(i1, i2)
|
||||
|
||||
|
||||
- class Exc(Exception): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Exc(Exception): pass
|
||||
|
||||
|
||||
d = self._empty_mapping()
|
||||
- class FailingUserDict:
|
||||
- def keys(self):
|
||||
- raise Exc
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class FailingUserDict:
|
||||
+ def keys(self):
|
||||
+ raise Exc
|
||||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
|
||||
d.clear()
|
||||
|
||||
|
||||
- class FailingUserDict:
|
||||
- def keys(self):
|
||||
- class BogonIter:
|
||||
|
|
@ -124,7 +124,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
- return BogonIter()
|
||||
- def __getitem__(self, key):
|
||||
- return key
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class FailingUserDict:
|
||||
+ def keys(self):
|
||||
+ class BogonIter:
|
||||
|
|
@ -141,7 +141,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
+ def __getitem__(self, key):
|
||||
+ return key
|
||||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
|
||||
- class FailingUserDict:
|
||||
- def keys(self):
|
||||
- class BogonIter:
|
||||
|
|
@ -158,7 +158,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
- return BogonIter()
|
||||
- def __getitem__(self, key):
|
||||
- raise Exc
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class FailingUserDict:
|
||||
+ def keys(self):
|
||||
+ class BogonIter:
|
||||
|
|
@ -176,26 +176,26 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
+ def __getitem__(self, key):
|
||||
+ raise Exc
|
||||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
|
||||
d = self._empty_mapping()
|
||||
- class badseq(object):
|
||||
- def __iter__(self):
|
||||
- return self
|
||||
- def __next__(self):
|
||||
- raise Exc()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class badseq(object):
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
+ def __next__(self):
|
||||
+ raise Exc()
|
||||
|
||||
|
||||
self.assertRaises(Exc, d.update, badseq())
|
||||
|
||||
|
||||
@@ -409,13 +469,14 @@ class TestMappingProtocol(BasicTestMappingProtocol):
|
||||
d.update(self._full_mapping({1:2, 3:4, 5:6}).items())
|
||||
self.assertEqual(d, {1:2, 2:4, 3:4, 5:6})
|
||||
|
||||
|
||||
- class SimpleUserDict:
|
||||
- def __init__(self):
|
||||
- self.d = {1:1, 2:2, 3:3}
|
||||
|
|
@ -203,7 +203,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
- return self.d.keys()
|
||||
- def __getitem__(self, i):
|
||||
- return self.d[i]
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class SimpleUserDict:
|
||||
+ def __init__(self):
|
||||
+ self.d = {1:1, 2:2, 3:3}
|
||||
|
|
@ -219,7 +219,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
self.assertEqual(d.fromkeys(g()), {1:None})
|
||||
self.assertRaises(TypeError, {}.fromkeys, 3)
|
||||
- class dictlike(self.type2test): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class dictlike(self.type2test): pass
|
||||
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
|
||||
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
|
||||
|
|
@ -229,7 +229,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
- class mydict(self.type2test):
|
||||
- def __new__(cls):
|
||||
- return collections.UserDict()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class mydict(self.type2test):
|
||||
+ def __new__(cls):
|
||||
+ return collections.UserDict()
|
||||
|
|
@ -237,52 +237,52 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
self.assertEqual(ud, {'a':None, 'b':None})
|
||||
self.assertIsInstance(ud, collections.UserDict)
|
||||
self.assertRaises(TypeError, dict.fromkeys)
|
||||
|
||||
|
||||
- class Exc(Exception): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Exc(Exception): pass
|
||||
|
||||
|
||||
- class baddict1(self.type2test):
|
||||
- def __init__(self, *args, **kwargs):
|
||||
- raise Exc()
|
||||
+ class baddict1(self.type2test):
|
||||
+ def __init__(self, *args, **kwargs):
|
||||
+ raise Exc()
|
||||
|
||||
|
||||
self.assertRaises(Exc, baddict1.fromkeys, [1])
|
||||
|
||||
|
||||
- class BadSeq(object):
|
||||
- def __iter__(self):
|
||||
- return self
|
||||
- def __next__(self):
|
||||
- raise Exc()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadSeq(object):
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
+ def __next__(self):
|
||||
+ raise Exc()
|
||||
|
||||
|
||||
self.assertRaises(Exc, self.type2test.fromkeys, BadSeq())
|
||||
|
||||
|
||||
- class baddict2(self.type2test):
|
||||
- def __setitem__(self, key, value):
|
||||
- raise Exc()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class baddict2(self.type2test):
|
||||
+ def __setitem__(self, key, value):
|
||||
+ raise Exc()
|
||||
|
||||
|
||||
self.assertRaises(Exc, baddict2.fromkeys, [1])
|
||||
|
||||
|
||||
@@ -537,25 +603,27 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
||||
|
||||
|
||||
def test_getitem(self):
|
||||
TestMappingProtocol.test_getitem(self)
|
||||
- class Exc(Exception): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Exc(Exception): pass
|
||||
|
||||
|
||||
- class BadEq(object):
|
||||
- def __eq__(self, other):
|
||||
- raise Exc()
|
||||
|
|
@ -293,11 +293,11 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
+ raise Exc()
|
||||
+ def __hash__(self):
|
||||
+ return 24
|
||||
|
||||
|
||||
d = self._empty_mapping()
|
||||
d[BadEq()] = 42
|
||||
self.assertRaises(KeyError, d.__getitem__, 23)
|
||||
|
||||
|
||||
- class BadHash(object):
|
||||
- fail = False
|
||||
- def __hash__(self):
|
||||
|
|
@ -305,7 +305,7 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
- raise Exc()
|
||||
- else:
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadHash(object):
|
||||
+ fail = False
|
||||
+ def __hash__(self):
|
||||
|
|
@ -313,17 +313,17 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
+ raise Exc()
|
||||
+ else:
|
||||
+ return 42
|
||||
|
||||
|
||||
d = self._empty_mapping()
|
||||
x = BadHash()
|
||||
@@ -565,9 +633,10 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
||||
|
||||
|
||||
def test_fromkeys(self):
|
||||
TestMappingProtocol.test_fromkeys(self)
|
||||
- class mydict(self.type2test):
|
||||
- def __new__(cls):
|
||||
- return collections.UserDict()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class mydict(self.type2test):
|
||||
+ def __new__(cls):
|
||||
+ return collections.UserDict()
|
||||
|
|
@ -333,11 +333,11 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
@@ -575,15 +644,16 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
||||
def test_pop(self):
|
||||
TestMappingProtocol.test_pop(self)
|
||||
|
||||
|
||||
- class Exc(Exception): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Exc(Exception): pass
|
||||
|
||||
|
||||
- class BadHash(object):
|
||||
- fail = False
|
||||
- def __hash__(self):
|
||||
|
|
@ -352,34 +352,34 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
+ raise Exc()
|
||||
+ else:
|
||||
+ return 42
|
||||
|
||||
|
||||
d = self._empty_mapping()
|
||||
x = BadHash()
|
||||
@@ -613,11 +683,12 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
||||
d[1] = d
|
||||
self.assertEqual(repr(d), '{1: {...}}')
|
||||
|
||||
|
||||
- class Exc(Exception): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Exc(Exception): pass
|
||||
|
||||
|
||||
- class BadRepr(object):
|
||||
- def __repr__(self):
|
||||
- raise Exc()
|
||||
+ class BadRepr(object):
|
||||
+ def __repr__(self):
|
||||
+ raise Exc()
|
||||
|
||||
|
||||
d = self._full_mapping({1: BadRepr()})
|
||||
self.assertRaises(Exc, repr, d)
|
||||
@@ -635,13 +706,14 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
||||
self.assertEqual(self._full_mapping({1: 2}),
|
||||
self._full_mapping({1: 2}))
|
||||
|
||||
|
||||
- class Exc(Exception): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Exc(Exception): pass
|
||||
|
||||
|
||||
- class BadCmp(object):
|
||||
- def __eq__(self, other):
|
||||
- raise Exc()
|
||||
|
|
@ -390,17 +390,17 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
+ raise Exc()
|
||||
+ def __hash__(self):
|
||||
+ return 1
|
||||
|
||||
|
||||
d1 = self._full_mapping({BadCmp(): 1})
|
||||
d2 = self._full_mapping({1: 1})
|
||||
@@ -651,15 +723,16 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
||||
def test_setdefault(self):
|
||||
TestMappingProtocol.test_setdefault(self)
|
||||
|
||||
|
||||
- class Exc(Exception): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Exc(Exception): pass
|
||||
|
||||
|
||||
- class BadHash(object):
|
||||
- fail = False
|
||||
- def __hash__(self):
|
||||
|
|
@ -415,6 +415,6 @@ index ed89a81a6ea..b19cec7cb23 100644
|
|||
+ raise Exc()
|
||||
+ else:
|
||||
+ return 42
|
||||
|
||||
|
||||
d = self._empty_mapping()
|
||||
x = BadHash()
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ class BasicTestMappingProtocol(__TestCase):
|
|||
self.assertRaises((TypeError, AttributeError), d.update, 42)
|
||||
|
||||
outerself = self
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SimpleUserDict:
|
||||
def __init__(self):
|
||||
self.d = outerself.reference
|
||||
|
|
@ -264,11 +264,11 @@ class BasicTestMappingProtocol(__TestCase):
|
|||
i2 = sorted(self.reference.items())
|
||||
self.assertEqual(i1, i2)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
d = self._empty_mapping()
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class FailingUserDict:
|
||||
def keys(self):
|
||||
raise Exc
|
||||
|
|
@ -276,7 +276,7 @@ class BasicTestMappingProtocol(__TestCase):
|
|||
|
||||
d.clear()
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class FailingUserDict:
|
||||
def keys(self):
|
||||
class BogonIter:
|
||||
|
|
@ -294,7 +294,7 @@ class BasicTestMappingProtocol(__TestCase):
|
|||
return key
|
||||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class FailingUserDict:
|
||||
def keys(self):
|
||||
class BogonIter:
|
||||
|
|
@ -314,7 +314,7 @@ class BasicTestMappingProtocol(__TestCase):
|
|||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
d = self._empty_mapping()
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class badseq(object):
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
@ -469,7 +469,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
|
|||
d.update(self._full_mapping({1:2, 3:4, 5:6}).items())
|
||||
self.assertEqual(d, {1:2, 2:4, 3:4, 5:6})
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SimpleUserDict:
|
||||
def __init__(self):
|
||||
self.d = {1:1, 2:2, 3:3}
|
||||
|
|
@ -492,14 +492,14 @@ class TestMappingProtocol(BasicTestMappingProtocol):
|
|||
yield 1
|
||||
self.assertEqual(d.fromkeys(g()), {1:None})
|
||||
self.assertRaises(TypeError, {}.fromkeys, 3)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class dictlike(self.type2test): pass
|
||||
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
|
||||
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
|
||||
self.assertTrue(dictlike.fromkeys('a').__class__ is dictlike)
|
||||
self.assertTrue(dictlike().fromkeys('a').__class__ is dictlike)
|
||||
self.assertTrue(type(dictlike.fromkeys('a')) is dictlike)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class mydict(self.type2test):
|
||||
def __new__(cls):
|
||||
return collections.UserDict()
|
||||
|
|
@ -508,7 +508,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
|
|||
self.assertIsInstance(ud, collections.UserDict)
|
||||
self.assertRaises(TypeError, dict.fromkeys)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class baddict1(self.type2test):
|
||||
|
|
@ -517,7 +517,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
|
|||
|
||||
self.assertRaises(Exc, baddict1.fromkeys, [1])
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadSeq(object):
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
@ -526,7 +526,7 @@ class TestMappingProtocol(BasicTestMappingProtocol):
|
|||
|
||||
self.assertRaises(Exc, self.type2test.fromkeys, BadSeq())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class baddict2(self.type2test):
|
||||
def __setitem__(self, key, value):
|
||||
raise Exc()
|
||||
|
|
@ -603,7 +603,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
|||
|
||||
def test_getitem(self):
|
||||
TestMappingProtocol.test_getitem(self)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadEq(object):
|
||||
|
|
@ -616,7 +616,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
|||
d[BadEq()] = 42
|
||||
self.assertRaises(KeyError, d.__getitem__, 23)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadHash(object):
|
||||
fail = False
|
||||
def __hash__(self):
|
||||
|
|
@ -633,7 +633,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
|||
|
||||
def test_fromkeys(self):
|
||||
TestMappingProtocol.test_fromkeys(self)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class mydict(self.type2test):
|
||||
def __new__(cls):
|
||||
return collections.UserDict()
|
||||
|
|
@ -644,7 +644,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
|||
def test_pop(self):
|
||||
TestMappingProtocol.test_pop(self)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadHash(object):
|
||||
|
|
@ -683,7 +683,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
|||
d[1] = d
|
||||
self.assertEqual(repr(d), '{1: {...}}')
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadRepr(object):
|
||||
|
|
@ -706,7 +706,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
|||
self.assertEqual(self._full_mapping({1: 2}),
|
||||
self._full_mapping({1: 2}))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadCmp(object):
|
||||
|
|
@ -723,7 +723,7 @@ class TestHashMappingProtocol(TestMappingProtocol):
|
|||
def test_setdefault(self):
|
||||
TestMappingProtocol.test_setdefault(self)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadHash(object):
|
||||
|
|
|
|||
|
|
@ -63,15 +63,15 @@ index 719c9434a16..290e57c04a0 100644
|
|||
@@ -95,7 +149,7 @@ class LyingList(list):
|
||||
def __iter__(self):
|
||||
yield 1
|
||||
|
||||
|
||||
-class CommonTest(unittest.TestCase):
|
||||
+class CommonTest(__TestCase):
|
||||
# The type to be tested
|
||||
type2test = None
|
||||
|
||||
|
||||
@@ -115,13 +169,14 @@ class CommonTest(unittest.TestCase):
|
||||
uu2 = self.type2test(u2)
|
||||
|
||||
|
||||
v = self.type2test(tuple(u))
|
||||
- class OtherSeq:
|
||||
- def __init__(self, initseq):
|
||||
|
|
@ -80,7 +80,7 @@ index 719c9434a16..290e57c04a0 100644
|
|||
- return len(self.__data)
|
||||
- def __getitem__(self, i):
|
||||
- return self.__data[i]
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class OtherSeq:
|
||||
+ def __init__(self, initseq):
|
||||
+ self.__data = initseq
|
||||
|
|
@ -100,51 +100,51 @@ index 719c9434a16..290e57c04a0 100644
|
|||
- class StopCompares:
|
||||
- def __eq__(self, other):
|
||||
- raise DoNotTestEq
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class DoNotTestEq(Exception):
|
||||
+ pass
|
||||
+ class StopCompares:
|
||||
+ def __eq__(self, other):
|
||||
+ raise DoNotTestEq
|
||||
|
||||
|
||||
checkfirst = self.type2test([1, StopCompares()])
|
||||
self.assertIn(1, checkfirst)
|
||||
@@ -283,8 +339,9 @@ class CommonTest(unittest.TestCase):
|
||||
self.assertEqual(u2+u2+u2, u2*3)
|
||||
self.assertEqual(u2+u2+u2, 3*u2)
|
||||
|
||||
|
||||
- class subclass(self.type2test):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass(self.type2test):
|
||||
+ pass
|
||||
u3 = subclass([0, 1])
|
||||
self.assertEqual(u3, u3*1)
|
||||
self.assertIsNot(u3, u3*1)
|
||||
@@ -311,9 +368,10 @@ class CommonTest(unittest.TestCase):
|
||||
|
||||
|
||||
def test_getitemoverwriteiter(self):
|
||||
# Verify that __getitem__ overrides are not recognized by __iter__
|
||||
- class T(self.type2test):
|
||||
- def __getitem__(self, key):
|
||||
- return str(key) + '!!!'
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class T(self.type2test):
|
||||
+ def __getitem__(self, key):
|
||||
+ return str(key) + '!!!'
|
||||
self.assertEqual(next(iter(T((1,2)))), 1)
|
||||
|
||||
|
||||
def test_repeat(self):
|
||||
@@ -361,14 +419,15 @@ class CommonTest(unittest.TestCase):
|
||||
|
||||
|
||||
self.assertRaises(TypeError, a.count)
|
||||
|
||||
|
||||
- class BadExc(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadExc(Exception):
|
||||
+ pass
|
||||
|
||||
|
||||
- class BadCmp:
|
||||
- def __eq__(self, other):
|
||||
- if other == 2:
|
||||
|
|
@ -155,19 +155,19 @@ index 719c9434a16..290e57c04a0 100644
|
|||
+ if other == 2:
|
||||
+ raise BadExc()
|
||||
+ return False
|
||||
|
||||
|
||||
self.assertRaises(BadExc, a.count, BadCmp())
|
||||
|
||||
|
||||
@@ -394,14 +453,15 @@ class CommonTest(unittest.TestCase):
|
||||
|
||||
|
||||
self.assertRaises(TypeError, u.index)
|
||||
|
||||
|
||||
- class BadExc(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadExc(Exception):
|
||||
+ pass
|
||||
|
||||
|
||||
- class BadCmp:
|
||||
- def __eq__(self, other):
|
||||
- if other == 2:
|
||||
|
|
@ -178,6 +178,6 @@ index 719c9434a16..290e57c04a0 100644
|
|||
+ if other == 2:
|
||||
+ raise BadExc()
|
||||
+ return False
|
||||
|
||||
|
||||
a = self.type2test([0, 1, 2, 3])
|
||||
self.assertRaises(BadExc, a.index, BadCmp())
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ class CommonTest(__TestCase):
|
|||
uu2 = self.type2test(u2)
|
||||
|
||||
v = self.type2test(tuple(u))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class OtherSeq:
|
||||
def __init__(self, initseq):
|
||||
self.__data = initseq
|
||||
|
|
@ -294,7 +294,7 @@ class CommonTest(__TestCase):
|
|||
# Sequences must test in-order. If a rich comparison has side
|
||||
# effects, these will be visible to tests against later members.
|
||||
# In this test, the "side effect" is a short-circuiting raise.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class DoNotTestEq(Exception):
|
||||
pass
|
||||
class StopCompares:
|
||||
|
|
@ -339,7 +339,7 @@ class CommonTest(__TestCase):
|
|||
self.assertEqual(u2+u2+u2, u2*3)
|
||||
self.assertEqual(u2+u2+u2, 3*u2)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass(self.type2test):
|
||||
pass
|
||||
u3 = subclass([0, 1])
|
||||
|
|
@ -368,7 +368,7 @@ class CommonTest(__TestCase):
|
|||
|
||||
def test_getitemoverwriteiter(self):
|
||||
# Verify that __getitem__ overrides are not recognized by __iter__
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class T(self.type2test):
|
||||
def __getitem__(self, key):
|
||||
return str(key) + '!!!'
|
||||
|
|
@ -419,7 +419,7 @@ class CommonTest(__TestCase):
|
|||
|
||||
self.assertRaises(TypeError, a.count)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadExc(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -453,7 +453,7 @@ class CommonTest(__TestCase):
|
|||
|
||||
self.assertRaises(TypeError, u.index)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadExc(Exception):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -24,20 +24,20 @@ index 34ecb45f161..12b719c432b 100644
|
|||
+# ======= END DYNAMO PATCH =======
|
||||
+
|
||||
# Test properties of bool promised by PEP 285
|
||||
|
||||
|
||||
import unittest
|
||||
@@ -5,12 +25,13 @@ from test.support import os_helper
|
||||
|
||||
|
||||
import os
|
||||
|
||||
|
||||
-class BoolTest(unittest.TestCase):
|
||||
+class BoolTest(__TestCase):
|
||||
|
||||
|
||||
def test_subclass(self):
|
||||
try:
|
||||
- class C(bool):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(bool):
|
||||
+ pass
|
||||
except TypeError:
|
||||
|
|
@ -50,67 +50,67 @@ index 34ecb45f161..12b719c432b 100644
|
|||
- class Foo(object):
|
||||
- def __bool__(self):
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Foo(object):
|
||||
+ def __bool__(self):
|
||||
+ return self
|
||||
check(Foo())
|
||||
|
||||
|
||||
- class Bar(object):
|
||||
- def __bool__(self):
|
||||
- return "Yes"
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Bar(object):
|
||||
+ def __bool__(self):
|
||||
+ return "Yes"
|
||||
check(Bar())
|
||||
|
||||
|
||||
- class Baz(int):
|
||||
- def __bool__(self):
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Baz(int):
|
||||
+ def __bool__(self):
|
||||
+ return self
|
||||
check(Baz())
|
||||
|
||||
|
||||
# __bool__() must return a bool not an int
|
||||
- class Spam(int):
|
||||
- def __bool__(self):
|
||||
- return 1
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Spam(int):
|
||||
+ def __bool__(self):
|
||||
+ return 1
|
||||
check(Spam())
|
||||
|
||||
|
||||
- class Eggs:
|
||||
- def __len__(self):
|
||||
- return -1
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Eggs:
|
||||
+ def __len__(self):
|
||||
+ return -1
|
||||
self.assertRaises(ValueError, bool, Eggs())
|
||||
|
||||
|
||||
def test_interpreter_convert_to_bool_raises(self):
|
||||
- class SymbolicBool:
|
||||
- def __bool__(self):
|
||||
- raise TypeError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class SymbolicBool:
|
||||
+ def __bool__(self):
|
||||
+ raise TypeError
|
||||
|
||||
|
||||
- class Symbol:
|
||||
- def __gt__(self, other):
|
||||
- return SymbolicBool()
|
||||
+ class Symbol:
|
||||
+ def __gt__(self, other):
|
||||
+ return SymbolicBool()
|
||||
|
||||
|
||||
x = Symbol()
|
||||
|
||||
|
||||
@@ -361,9 +388,10 @@ class BoolTest(unittest.TestCase):
|
||||
# this test just tests our assumptions about __len__
|
||||
# this will start failing if __len__ changes assertions
|
||||
|
|
@ -118,7 +118,7 @@ index 34ecb45f161..12b719c432b 100644
|
|||
- class A:
|
||||
- def __len__(self):
|
||||
- return badval
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ def __len__(self):
|
||||
+ return badval
|
||||
|
|
@ -127,30 +127,30 @@ index 34ecb45f161..12b719c432b 100644
|
|||
except (Exception) as e_bool:
|
||||
@@ -373,14 +401,16 @@ class BoolTest(unittest.TestCase):
|
||||
self.assertEqual(str(e_bool), str(e_len))
|
||||
|
||||
|
||||
def test_blocked(self):
|
||||
- class A:
|
||||
- __bool__ = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ __bool__ = None
|
||||
self.assertRaises(TypeError, bool, A())
|
||||
|
||||
|
||||
- class B:
|
||||
- def __len__(self):
|
||||
- return 10
|
||||
- __bool__ = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class B:
|
||||
+ def __len__(self):
|
||||
+ return 10
|
||||
+ __bool__ = None
|
||||
self.assertRaises(TypeError, bool, B())
|
||||
|
||||
|
||||
def test_real_and_imag(self):
|
||||
@@ -394,12 +424,13 @@ class BoolTest(unittest.TestCase):
|
||||
self.assertIs(type(False.imag), int)
|
||||
|
||||
|
||||
def test_bool_called_at_least_once(self):
|
||||
- class X:
|
||||
- def __init__(self):
|
||||
|
|
@ -158,19 +158,19 @@ index 34ecb45f161..12b719c432b 100644
|
|||
- def __bool__(self):
|
||||
- self.count += 1
|
||||
- return True
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class X:
|
||||
+ def __init__(self):
|
||||
+ self.count = 0
|
||||
+ def __bool__(self):
|
||||
+ self.count += 1
|
||||
+ return True
|
||||
|
||||
|
||||
def f(x):
|
||||
if x or True:
|
||||
@@ -418,4 +449,4 @@ class BoolTest(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class BoolTest(__TestCase):
|
|||
|
||||
def test_subclass(self):
|
||||
try:
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(bool):
|
||||
pass
|
||||
except TypeError:
|
||||
|
|
@ -328,39 +328,39 @@ class BoolTest(__TestCase):
|
|||
# from __bool__(). This isn't really a bool test, but
|
||||
# it's related.
|
||||
check = lambda o: self.assertRaises(TypeError, bool, o)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Foo(object):
|
||||
def __bool__(self):
|
||||
return self
|
||||
check(Foo())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Bar(object):
|
||||
def __bool__(self):
|
||||
return "Yes"
|
||||
check(Bar())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Baz(int):
|
||||
def __bool__(self):
|
||||
return self
|
||||
check(Baz())
|
||||
|
||||
# __bool__() must return a bool not an int
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Spam(int):
|
||||
def __bool__(self):
|
||||
return 1
|
||||
check(Spam())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Eggs:
|
||||
def __len__(self):
|
||||
return -1
|
||||
self.assertRaises(ValueError, bool, Eggs())
|
||||
|
||||
def test_interpreter_convert_to_bool_raises(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SymbolicBool:
|
||||
def __bool__(self):
|
||||
raise TypeError
|
||||
|
|
@ -388,7 +388,7 @@ class BoolTest(__TestCase):
|
|||
# this test just tests our assumptions about __len__
|
||||
# this will start failing if __len__ changes assertions
|
||||
for badval in ['illegal', -1, 1 << 32]:
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
def __len__(self):
|
||||
return badval
|
||||
|
|
@ -401,12 +401,12 @@ class BoolTest(__TestCase):
|
|||
self.assertEqual(str(e_bool), str(e_len))
|
||||
|
||||
def test_blocked(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
__bool__ = None
|
||||
self.assertRaises(TypeError, bool, A())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class B:
|
||||
def __len__(self):
|
||||
return 10
|
||||
|
|
@ -424,7 +424,7 @@ class BoolTest(__TestCase):
|
|||
self.assertIs(type(False.imag), int)
|
||||
|
||||
def test_bool_called_at_least_once(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ index a96a5780b31..d00dfca8a17 100644
|
|||
@@ -50,7 +103,7 @@ complex_nans = [complex(x, y) for x, y in [
|
||||
(INF, NAN)
|
||||
]]
|
||||
|
||||
|
||||
-class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
+class CMathTests(__TestCase):
|
||||
# list of all functions in cmath
|
||||
|
|
@ -74,7 +74,7 @@ index a96a5780b31..d00dfca8a17 100644
|
|||
@@ -66,6 +119,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
self.test_values.close()
|
||||
|
||||
|
||||
+ def assertFloatIdentical(self, x, y):
|
||||
+ """Fail unless floats x and y are identical, in the sense that:
|
||||
+ (1) both x and y are nans, or
|
||||
|
|
@ -113,7 +113,7 @@ index a96a5780b31..d00dfca8a17 100644
|
|||
"""Fail if the two floating-point numbers are not almost equal.
|
||||
@@ -165,38 +251,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
# end up being passed to the cmath functions
|
||||
|
||||
|
||||
# usual case: new-style class implementing __complex__
|
||||
- class MyComplex:
|
||||
- def __init__(self, value):
|
||||
|
|
@ -127,7 +127,7 @@ index a96a5780b31..d00dfca8a17 100644
|
|||
- class MyComplexException:
|
||||
- def __complex__(self):
|
||||
- raise SomeException
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyComplex:
|
||||
+ def __init__(self, value):
|
||||
+ self.value = value
|
||||
|
|
@ -140,7 +140,7 @@ index a96a5780b31..d00dfca8a17 100644
|
|||
+ class MyComplexException:
|
||||
+ def __complex__(self):
|
||||
+ raise SomeException
|
||||
|
||||
|
||||
- # some classes not providing __float__ or __complex__
|
||||
- class NeitherComplexNorFloat(object):
|
||||
- pass
|
||||
|
|
@ -179,12 +179,12 @@ index a96a5780b31..d00dfca8a17 100644
|
|||
+ class JustFloat:
|
||||
+ def __float__(self):
|
||||
+ return flt_arg
|
||||
|
||||
|
||||
for f in self.test_functions:
|
||||
# usual usage
|
||||
@@ -590,4 +677,4 @@ class IsCloseTests(test_math.IsCloseTests):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ class CMathTests(__TestCase):
|
|||
# end up being passed to the cmath functions
|
||||
|
||||
# usual case: new-style class implementing __complex__
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyComplex:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
|
|
|||
|
|
@ -24,12 +24,12 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+# ======= END DYNAMO PATCH =======
|
||||
+
|
||||
"""Unit tests for collections.py."""
|
||||
|
||||
|
||||
import array
|
||||
@@ -29,7 +49,7 @@ from collections.abc import Sequence, MutableSequence
|
||||
from collections.abc import ByteString, Buffer
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestUserObjects(unittest.TestCase):
|
||||
+class TestUserObjects(__TestCase):
|
||||
def _superset_test(self, a, b):
|
||||
|
|
@ -37,12 +37,12 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
set(dir(a)),
|
||||
@@ -73,9 +93,10 @@ class TestUserObjects(unittest.TestCase):
|
||||
self._copy_test(obj)
|
||||
|
||||
|
||||
def test_dict_missing(self):
|
||||
- class A(UserDict):
|
||||
- def __missing__(self, key):
|
||||
- return 456
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A(UserDict):
|
||||
+ def __missing__(self, key):
|
||||
+ return 456
|
||||
|
|
@ -52,20 +52,20 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
@@ -85,7 +106,7 @@ class TestUserObjects(unittest.TestCase):
|
||||
### ChainMap (helper class for configparser and the string module)
|
||||
################################################################################
|
||||
|
||||
|
||||
-class TestChainMap(unittest.TestCase):
|
||||
+class TestChainMap(__TestCase):
|
||||
|
||||
|
||||
def test_basics(self):
|
||||
c = ChainMap()
|
||||
@@ -172,9 +193,10 @@ class TestChainMap(unittest.TestCase):
|
||||
self.assertTrue(ChainMap({}, {1:2}))
|
||||
|
||||
|
||||
def test_missing(self):
|
||||
- class DefaultChainMap(ChainMap):
|
||||
- def __missing__(self, key):
|
||||
- return 999
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class DefaultChainMap(ChainMap):
|
||||
+ def __missing__(self, key):
|
||||
+ return 999
|
||||
|
|
@ -74,7 +74,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
self.assertEqual(d[k], v) # check __getitem__ w/missing
|
||||
@@ -206,13 +228,14 @@ class TestChainMap(unittest.TestCase):
|
||||
('i', 9999), ('j', 0)])
|
||||
|
||||
|
||||
def test_iter_not_calling_getitem_on_maps(self):
|
||||
- class DictWithGetItem(UserDict):
|
||||
- def __init__(self, *args, **kwds):
|
||||
|
|
@ -83,7 +83,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- def __getitem__(self, item):
|
||||
- self.called = True
|
||||
- UserDict.__getitem__(self, item)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class DictWithGetItem(UserDict):
|
||||
+ def __init__(self, *args, **kwds):
|
||||
+ self.called = False
|
||||
|
|
@ -91,12 +91,12 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ def __getitem__(self, item):
|
||||
+ self.called = True
|
||||
+ UserDict.__getitem__(self, item)
|
||||
|
||||
|
||||
d = DictWithGetItem(a=1)
|
||||
c = ChainMap(d)
|
||||
@@ -237,15 +260,16 @@ class TestChainMap(unittest.TestCase):
|
||||
self.assertIs(m, d.maps[0])
|
||||
|
||||
|
||||
# Use a different map than a dict
|
||||
- class lowerdict(dict):
|
||||
- def __getitem__(self, key):
|
||||
|
|
@ -107,7 +107,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- if isinstance(key, str):
|
||||
- key = key.lower()
|
||||
- return dict.__contains__(self, key)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class lowerdict(dict):
|
||||
+ def __getitem__(self, key):
|
||||
+ if isinstance(key, str):
|
||||
|
|
@ -117,46 +117,46 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ if isinstance(key, str):
|
||||
+ key = key.lower()
|
||||
+ return dict.__contains__(self, key)
|
||||
|
||||
|
||||
c = ChainMap()
|
||||
c['a'] = 1
|
||||
@@ -315,7 +339,7 @@ class TestChainMap(unittest.TestCase):
|
||||
|
||||
|
||||
TestNT = namedtuple('TestNT', 'x y z') # type used for pickle tests
|
||||
|
||||
|
||||
-class TestNamedTuple(unittest.TestCase):
|
||||
+class TestNamedTuple(__TestCase):
|
||||
|
||||
|
||||
def test_factory(self):
|
||||
Point = namedtuple('Point', 'x y')
|
||||
@@ -666,8 +690,9 @@ class TestNamedTuple(unittest.TestCase):
|
||||
NT = namedtuple('NT', ['abc', 'def'], False, True)
|
||||
|
||||
|
||||
def test_namedtuple_subclass_issue_24931(self):
|
||||
- class Point(namedtuple('_Point', ['x', 'y'])):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Point(namedtuple('_Point', ['x', 'y'])):
|
||||
+ pass
|
||||
|
||||
|
||||
a = Point(3, 4)
|
||||
self.assertEqual(a._asdict(), OrderedDict([('x', 3), ('y', 4)]))
|
||||
@@ -722,21 +747,26 @@ class TestNamedTuple(unittest.TestCase):
|
||||
### Abstract Base Classes
|
||||
################################################################################
|
||||
|
||||
|
||||
-class ABCTestCase(unittest.TestCase):
|
||||
+class ABCTestCase(__TestCase):
|
||||
|
||||
|
||||
def validate_abstract_methods(self, abc, *names):
|
||||
methodstubs = dict.fromkeys(names, lambda s, *args: 0)
|
||||
|
||||
|
||||
# everything should work will all required methods are present
|
||||
- C = type('C', (abc,), methodstubs)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ C = type('C', (abc,), methodstubs)
|
||||
C()
|
||||
|
||||
|
||||
+ # Dynamo raises a hard error here that we can't easily capture
|
||||
+ # Commenting this part as this would also fail in eager if a user
|
||||
+ # attempt to run the same code
|
||||
|
|
@ -172,7 +172,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ # del stubs[name]
|
||||
+ # C = type('C', (abc,), stubs)
|
||||
+ # self.assertRaises(TypeError, C, name)
|
||||
|
||||
|
||||
def validate_isinstance(self, abc, name):
|
||||
stub = lambda s, *args: 0
|
||||
@@ -981,19 +1011,21 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
||||
|
|
@ -183,7 +183,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- class I(Iterable):
|
||||
- def __iter__(self):
|
||||
- return super().__iter__()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check direct subclassing
|
||||
+ class I(Iterable):
|
||||
+ def __iter__(self):
|
||||
|
|
@ -197,7 +197,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- def __iter__(self): return iter([])
|
||||
- class ItBlocked(It):
|
||||
- __iter__ = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check None blocking
|
||||
+ class It:
|
||||
+ def __iter__(self): return iter([])
|
||||
|
|
@ -216,7 +216,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return iter(list())
|
||||
- def __reversed__(self):
|
||||
- return iter(list())
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check direct subclassing
|
||||
+ class R(Reversible):
|
||||
+ def __iter__(self):
|
||||
|
|
@ -231,7 +231,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- def __reversed__(self): return reversed([])
|
||||
- class RevPlusIter(RevNoIter):
|
||||
- def __iter__(self): return iter([])
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check reversible non-iterable (which is not Reversible)
|
||||
+ class RevNoIter:
|
||||
+ def __reversed__(self): return reversed([])
|
||||
|
|
@ -249,7 +249,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- __iter__ = None
|
||||
- class RevRevBlocked(Rev):
|
||||
- __reversed__ = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check None blocking
|
||||
+ class Rev:
|
||||
+ def __iter__(self): return iter([])
|
||||
|
|
@ -274,7 +274,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- def __contains__(self, item):
|
||||
- return False
|
||||
- class DerCol(Col): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check direct subclassing
|
||||
+ class Col(Collection):
|
||||
+ def __iter__(self):
|
||||
|
|
@ -300,7 +300,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- class ColNoCont:
|
||||
- def __iter__(self): return iter([])
|
||||
- def __len__(self): return 0
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ColNoIter:
|
||||
+ def __len__(self): return 0
|
||||
+ def __contains__(self, item): return False
|
||||
|
|
@ -326,7 +326,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- def __contains__(self): return True
|
||||
- __iter__ = None
|
||||
+
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check None blocking
|
||||
+ class SizeBlock:
|
||||
+ def __iter__(self): return iter([])
|
||||
|
|
@ -350,7 +350,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return False
|
||||
- class NonCol(ColImpl):
|
||||
- __contains__ = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Check None blocking in subclass
|
||||
+ class ColImpl:
|
||||
+ def __iter__(self):
|
||||
|
|
@ -363,24 +363,24 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ __contains__ = None
|
||||
self.assertFalse(issubclass(NonCol, Collection))
|
||||
self.assertFalse(isinstance(NonCol(), Collection))
|
||||
|
||||
|
||||
@@ -1162,30 +1202,32 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
||||
self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
|
||||
self.validate_abstract_methods(Iterator, '__next__', '__iter__')
|
||||
|
||||
|
||||
- # Issue 10565
|
||||
- class NextOnly:
|
||||
- def __next__(self):
|
||||
- yield 1
|
||||
- return
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Issue 10565
|
||||
+ class NextOnly:
|
||||
+ def __next__(self):
|
||||
+ yield 1
|
||||
+ return
|
||||
self.assertNotIsInstance(NextOnly(), Iterator)
|
||||
|
||||
|
||||
def test_Generator(self):
|
||||
- class NonGen1:
|
||||
- def __iter__(self): return self
|
||||
|
|
@ -398,7 +398,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- def close(self): pass
|
||||
- def send(self, value): return value
|
||||
- def throw(self, typ, val=None, tb=None): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class NonGen1:
|
||||
+ def __iter__(self): return self
|
||||
+ def __next__(self): return None
|
||||
|
|
@ -415,27 +415,27 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ def close(self): pass
|
||||
+ def send(self, value): return value
|
||||
+ def throw(self, typ, val=None, tb=None): pass
|
||||
|
||||
|
||||
non_samples = [
|
||||
None, 42, 3.14, 1j, b"", "", (), [], {}, set(),
|
||||
@@ -1194,18 +1236,19 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
||||
self.assertNotIsInstance(x, Generator)
|
||||
self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
|
||||
|
||||
|
||||
- class Gen:
|
||||
- def __iter__(self): return self
|
||||
- def __next__(self): return None
|
||||
- def close(self): pass
|
||||
- def send(self, value): return value
|
||||
- def throw(self, typ, val=None, tb=None): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Gen:
|
||||
+ def __iter__(self): return self
|
||||
+ def __next__(self): return None
|
||||
+ def close(self): pass
|
||||
+ def send(self, value): return value
|
||||
+ def throw(self, typ, val=None, tb=None): pass
|
||||
|
||||
|
||||
- class MinimalGen(Generator):
|
||||
- def send(self, value):
|
||||
- return value
|
||||
|
|
@ -446,50 +446,50 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ return value
|
||||
+ def throw(self, typ, val=None, tb=None):
|
||||
+ super().throw(typ, val, tb)
|
||||
|
||||
|
||||
def gen():
|
||||
yield 1
|
||||
@@ -1228,15 +1271,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
||||
mgen.throw, ValueError, ValueError("huhu"))
|
||||
self.assertRaises(StopIteration, mgen.throw, StopIteration())
|
||||
|
||||
|
||||
- class FailOnClose(Generator):
|
||||
- def send(self, value): return value
|
||||
- def throw(self, *args): raise ValueError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class FailOnClose(Generator):
|
||||
+ def send(self, value): return value
|
||||
+ def throw(self, *args): raise ValueError
|
||||
|
||||
|
||||
self.assertRaises(ValueError, FailOnClose().close)
|
||||
|
||||
|
||||
- class IgnoreGeneratorExit(Generator):
|
||||
- def send(self, value): return value
|
||||
- def throw(self, *args): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class IgnoreGeneratorExit(Generator):
|
||||
+ def send(self, value): return value
|
||||
+ def throw(self, *args): pass
|
||||
|
||||
|
||||
self.assertRaises(RuntimeError, IgnoreGeneratorExit().close)
|
||||
|
||||
|
||||
@@ -1379,15 +1424,17 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
||||
|
||||
|
||||
def test_direct_subclassing(self):
|
||||
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
|
||||
- class C(B):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(B):
|
||||
+ pass
|
||||
self.assertTrue(issubclass(C, B))
|
||||
self.assertFalse(issubclass(int, C))
|
||||
|
||||
|
||||
def test_registration(self):
|
||||
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
|
||||
- class C:
|
||||
- __hash__ = None # Make sure it isn't hashable by default
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ __hash__ = None # Make sure it isn't hashable by default
|
||||
self.assertFalse(issubclass(C, B), B.__name__)
|
||||
|
|
@ -506,7 +506,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return 0
|
||||
- def __iter__(self):
|
||||
- return iter([])
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySet(Set):
|
||||
+ def __contains__(self, x):
|
||||
+ return False
|
||||
|
|
@ -515,11 +515,11 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ def __iter__(self):
|
||||
+ return iter([])
|
||||
self.validate_comparison(MySet())
|
||||
|
||||
|
||||
def test_hash_Set(self):
|
||||
@@ -1448,15 +1496,16 @@ class TestCollectionABCs(ABCTestCase):
|
||||
self.assertTrue(hash(a) == hash(b))
|
||||
|
||||
|
||||
def test_isdisjoint_Set(self):
|
||||
- class MySet(Set):
|
||||
- def __init__(self, itr):
|
||||
|
|
@ -530,7 +530,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return iter(self.contents)
|
||||
- def __len__(self):
|
||||
- return len([x for x in self.contents])
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySet(Set):
|
||||
+ def __init__(self, itr):
|
||||
+ self.contents = itr
|
||||
|
|
@ -545,7 +545,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
s3 = MySet((1, 5, 6))
|
||||
@@ -1464,15 +1513,16 @@ class TestCollectionABCs(ABCTestCase):
|
||||
self.assertFalse(s1.isdisjoint(s3))
|
||||
|
||||
|
||||
def test_equality_Set(self):
|
||||
- class MySet(Set):
|
||||
- def __init__(self, itr):
|
||||
|
|
@ -556,7 +556,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return iter(self.contents)
|
||||
- def __len__(self):
|
||||
- return len([x for x in self.contents])
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySet(Set):
|
||||
+ def __init__(self, itr):
|
||||
+ self.contents = itr
|
||||
|
|
@ -571,7 +571,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
s3 = MySet((3, 4))
|
||||
@@ -1486,15 +1536,16 @@ class TestCollectionABCs(ABCTestCase):
|
||||
self.assertNotEqual(s2, s3)
|
||||
|
||||
|
||||
def test_arithmetic_Set(self):
|
||||
- class MySet(Set):
|
||||
- def __init__(self, itr):
|
||||
|
|
@ -582,7 +582,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return iter(self.contents)
|
||||
- def __len__(self):
|
||||
- return len([x for x in self.contents])
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySet(Set):
|
||||
+ def __init__(self, itr):
|
||||
+ self.contents = itr
|
||||
|
|
@ -596,7 +596,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
s2 = MySet((3, 4, 5))
|
||||
s3 = s1 & s2
|
||||
@@ -1516,28 +1567,29 @@ class TestCollectionABCs(ABCTestCase):
|
||||
|
||||
|
||||
def test_issue_4920(self):
|
||||
# MutableSet.pop() method did not work
|
||||
- class MySet(MutableSet):
|
||||
|
|
@ -621,7 +621,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return result
|
||||
- def __repr__(self):
|
||||
- return "MySet(%s)" % repr(list(self))
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySet(MutableSet):
|
||||
+ __slots__=['__s']
|
||||
+ def __init__(self,items=None):
|
||||
|
|
@ -669,7 +669,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return NotImplemented
|
||||
- def __lt__(self, x):
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyComparableSet(Set):
|
||||
+ def __contains__(self, x):
|
||||
+ return False
|
||||
|
|
@ -688,11 +688,11 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ return NotImplemented
|
||||
+ def __lt__(self, x):
|
||||
+ return NotImplemented
|
||||
|
||||
|
||||
cs = MyComparableSet()
|
||||
ncs = MyNonComparableSet()
|
||||
@@ -1591,13 +1644,14 @@ class TestCollectionABCs(ABCTestCase):
|
||||
|
||||
|
||||
def test_issue26915(self):
|
||||
# Container membership test should check identity first
|
||||
- class CustomSequence(Sequence):
|
||||
|
|
@ -702,7 +702,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return self._seq[index]
|
||||
- def __len__(self):
|
||||
- return len(self._seq)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class CustomSequence(Sequence):
|
||||
+ def __init__(self, seq):
|
||||
+ self._seq = seq
|
||||
|
|
@ -710,11 +710,11 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ return self._seq[index]
|
||||
+ def __len__(self):
|
||||
+ return len(self._seq)
|
||||
|
||||
|
||||
nan = float('nan')
|
||||
obj = support.NEVER_EQ
|
||||
@@ -1622,30 +1676,31 @@ class TestCollectionABCs(ABCTestCase):
|
||||
|
||||
|
||||
def test_Set_from_iterable(self):
|
||||
"""Verify _from_iterable overridden to an instance method works."""
|
||||
- class SetUsingInstanceFromIterable(MutableSet):
|
||||
|
|
@ -723,48 +723,48 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- raise ValueError('created_by must be specified')
|
||||
- self.created_by = created_by
|
||||
- self._values = set(values)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class SetUsingInstanceFromIterable(MutableSet):
|
||||
+ def __init__(self, values, created_by):
|
||||
+ if not created_by:
|
||||
+ raise ValueError('created_by must be specified')
|
||||
+ self.created_by = created_by
|
||||
+ self._values = set(values)
|
||||
|
||||
|
||||
- def _from_iterable(self, values):
|
||||
- return type(self)(values, 'from_iterable')
|
||||
+ def _from_iterable(self, values):
|
||||
+ return type(self)(values, 'from_iterable')
|
||||
|
||||
|
||||
- def __contains__(self, value):
|
||||
- return value in self._values
|
||||
+ def __contains__(self, value):
|
||||
+ return value in self._values
|
||||
|
||||
|
||||
- def __iter__(self):
|
||||
- yield from self._values
|
||||
+ def __iter__(self):
|
||||
+ yield from self._values
|
||||
|
||||
|
||||
- def __len__(self):
|
||||
- return len(self._values)
|
||||
+ def __len__(self):
|
||||
+ return len(self._values)
|
||||
|
||||
|
||||
- def add(self, value):
|
||||
- self._values.add(value)
|
||||
+ def add(self, value):
|
||||
+ self._values.add(value)
|
||||
|
||||
|
||||
- def discard(self, value):
|
||||
- self._values.discard(value)
|
||||
+ def discard(self, value):
|
||||
+ self._values.discard(value)
|
||||
|
||||
|
||||
impl = SetUsingInstanceFromIterable([1, 2, 3], 'test')
|
||||
|
||||
|
||||
@@ -1678,20 +1733,21 @@ class TestCollectionABCs(ABCTestCase):
|
||||
|
||||
|
||||
def test_Set_interoperability_with_real_sets(self):
|
||||
# Issue: 8743
|
||||
- class ListSet(Set):
|
||||
|
|
@ -781,7 +781,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- return len(self.data)
|
||||
- def __repr__(self):
|
||||
- return 'Set({!r})'.format(self.data)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ListSet(Set):
|
||||
+ def __init__(self, elements=()):
|
||||
+ self.data = []
|
||||
|
|
@ -796,7 +796,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ return len(self.data)
|
||||
+ def __repr__(self):
|
||||
+ return 'Set({!r})'.format(self.data)
|
||||
|
||||
|
||||
r1 = set('abc')
|
||||
r2 = set('bcd')
|
||||
@@ -1846,13 +1902,14 @@ class TestCollectionABCs(ABCTestCase):
|
||||
|
|
@ -810,7 +810,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- raise IndexError
|
||||
- def __iter__(self):
|
||||
- return iter(())
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyMapping(Mapping):
|
||||
+ def __len__(self):
|
||||
+ return 0
|
||||
|
|
@ -820,7 +820,7 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
+ return iter(())
|
||||
self.validate_comparison(MyMapping())
|
||||
self.assertRaises(TypeError, reversed, MyMapping())
|
||||
|
||||
|
||||
@@ -1860,7 +1917,7 @@ class TestCollectionABCs(ABCTestCase):
|
||||
for sample in [dict]:
|
||||
self.assertIsInstance(sample(), MutableMapping)
|
||||
|
|
@ -828,30 +828,30 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- self.validate_abstract_methods(MutableMapping, '__contains__', '__iter__', '__len__',
|
||||
+ self.validate_abstract_methods(MutableMapping, '__iter__', '__len__',
|
||||
'__getitem__', '__setitem__', '__delitem__')
|
||||
|
||||
|
||||
def test_MutableMapping_subclass(self):
|
||||
@@ -1903,15 +1960,16 @@ class TestCollectionABCs(ABCTestCase):
|
||||
'__getitem__')
|
||||
|
||||
|
||||
def test_Sequence_mixins(self):
|
||||
- class SequenceSubclass(Sequence):
|
||||
- def __init__(self, seq=()):
|
||||
- self.seq = seq
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class SequenceSubclass(Sequence):
|
||||
+ def __init__(self, seq=()):
|
||||
+ self.seq = seq
|
||||
|
||||
|
||||
- def __getitem__(self, index):
|
||||
- return self.seq[index]
|
||||
+ def __getitem__(self, index):
|
||||
+ return self.seq[index]
|
||||
|
||||
|
||||
- def __len__(self):
|
||||
- return len(self.seq)
|
||||
+ def __len__(self):
|
||||
+ return len(self.seq)
|
||||
|
||||
|
||||
# Compare Sequence.index() behavior to (list|str).index() behavior
|
||||
def assert_index_same(seq1, seq2, index_args):
|
||||
@@ -1983,24 +2041,25 @@ class TestCollectionABCs(ABCTestCase):
|
||||
|
|
@ -861,54 +861,54 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
- class MutableSequenceSubclass(MutableSequence):
|
||||
- def __init__(self):
|
||||
- self.lst = []
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MutableSequenceSubclass(MutableSequence):
|
||||
+ def __init__(self):
|
||||
+ self.lst = []
|
||||
|
||||
|
||||
- def __setitem__(self, index, value):
|
||||
- self.lst[index] = value
|
||||
+ def __setitem__(self, index, value):
|
||||
+ self.lst[index] = value
|
||||
|
||||
|
||||
- def __getitem__(self, index):
|
||||
- return self.lst[index]
|
||||
+ def __getitem__(self, index):
|
||||
+ return self.lst[index]
|
||||
|
||||
|
||||
- def __len__(self):
|
||||
- return len(self.lst)
|
||||
+ def __len__(self):
|
||||
+ return len(self.lst)
|
||||
|
||||
|
||||
- def __delitem__(self, index):
|
||||
- del self.lst[index]
|
||||
+ def __delitem__(self, index):
|
||||
+ del self.lst[index]
|
||||
|
||||
|
||||
- def insert(self, index, value):
|
||||
- self.lst.insert(index, value)
|
||||
+ def insert(self, index, value):
|
||||
+ self.lst.insert(index, value)
|
||||
|
||||
|
||||
mss = MutableSequenceSubclass()
|
||||
mss.append(0)
|
||||
@@ -2059,7 +2118,7 @@ class CounterSubclassWithGet(Counter):
|
||||
self.called = True
|
||||
return Counter.get(self, key, default)
|
||||
|
||||
|
||||
-class TestCounter(unittest.TestCase):
|
||||
+class TestCounter(__TestCase):
|
||||
|
||||
|
||||
def test_basics(self):
|
||||
c = Counter('abcaba')
|
||||
@@ -2225,8 +2284,9 @@ class TestCounter(unittest.TestCase):
|
||||
check(Counter(words))
|
||||
|
||||
|
||||
def test_copy_subclass(self):
|
||||
- class MyCounter(Counter):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyCounter(Counter):
|
||||
+ pass
|
||||
c = MyCounter('slartibartfast')
|
||||
|
|
@ -916,8 +916,8 @@ index cafc44007d1..4571e5a14fd 100644
|
|||
self.assertEqual(d, c)
|
||||
@@ -2402,10 +2462,5 @@ class TestCounter(unittest.TestCase):
|
||||
self.assertFalse(Counter(a=2, b=1, c=0) > Counter('aab'))
|
||||
|
||||
|
||||
|
||||
|
||||
-def load_tests(loader, tests, pattern):
|
||||
- tests.addTest(doctest.DocTestSuite(collections))
|
||||
- return tests
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class TestUserObjects(__TestCase):
|
|||
self._copy_test(obj)
|
||||
|
||||
def test_dict_missing(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A(UserDict):
|
||||
def __missing__(self, key):
|
||||
return 456
|
||||
|
|
@ -193,7 +193,7 @@ class TestChainMap(__TestCase):
|
|||
self.assertTrue(ChainMap({}, {1:2}))
|
||||
|
||||
def test_missing(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class DefaultChainMap(ChainMap):
|
||||
def __missing__(self, key):
|
||||
return 999
|
||||
|
|
@ -228,7 +228,7 @@ class TestChainMap(__TestCase):
|
|||
('i', 9999), ('j', 0)])
|
||||
|
||||
def test_iter_not_calling_getitem_on_maps(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class DictWithGetItem(UserDict):
|
||||
def __init__(self, *args, **kwds):
|
||||
self.called = False
|
||||
|
|
@ -260,7 +260,7 @@ class TestChainMap(__TestCase):
|
|||
self.assertIs(m, d.maps[0])
|
||||
|
||||
# Use a different map than a dict
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class lowerdict(dict):
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, str):
|
||||
|
|
@ -690,7 +690,7 @@ class TestNamedTuple(__TestCase):
|
|||
NT = namedtuple('NT', ['abc', 'def'], False, True)
|
||||
|
||||
def test_namedtuple_subclass_issue_24931(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Point(namedtuple('_Point', ['x', 'y'])):
|
||||
pass
|
||||
|
||||
|
|
@ -753,7 +753,7 @@ class ABCTestCase(__TestCase):
|
|||
methodstubs = dict.fromkeys(names, lambda s, *args: 0)
|
||||
|
||||
# everything should work will all required methods are present
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
C = type('C', (abc,), methodstubs)
|
||||
C()
|
||||
|
||||
|
|
@ -1011,7 +1011,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
for x in samples:
|
||||
self.assertIsInstance(x, Iterable)
|
||||
self.assertTrue(issubclass(type(x), Iterable), repr(type(x)))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check direct subclassing
|
||||
class I(Iterable):
|
||||
def __iter__(self):
|
||||
|
|
@ -1020,7 +1020,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertFalse(issubclass(str, I))
|
||||
self.validate_abstract_methods(Iterable, '__iter__')
|
||||
self.validate_isinstance(Iterable, '__iter__')
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check None blocking
|
||||
class It:
|
||||
def __iter__(self): return iter([])
|
||||
|
|
@ -1055,7 +1055,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertTrue(issubclass(Sequence, Reversible), repr(Sequence))
|
||||
self.assertFalse(issubclass(Mapping, Reversible), repr(Mapping))
|
||||
self.assertFalse(issubclass(MutableMapping, Reversible), repr(MutableMapping))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check direct subclassing
|
||||
class R(Reversible):
|
||||
def __iter__(self):
|
||||
|
|
@ -1065,7 +1065,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertEqual(list(reversed(R())), [])
|
||||
self.assertFalse(issubclass(float, R))
|
||||
self.validate_abstract_methods(Reversible, '__reversed__', '__iter__')
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check reversible non-iterable (which is not Reversible)
|
||||
class RevNoIter:
|
||||
def __reversed__(self): return reversed([])
|
||||
|
|
@ -1075,7 +1075,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertFalse(isinstance(RevNoIter(), Reversible))
|
||||
self.assertTrue(issubclass(RevPlusIter, Reversible))
|
||||
self.assertTrue(isinstance(RevPlusIter(), Reversible))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check None blocking
|
||||
class Rev:
|
||||
def __iter__(self): return iter([])
|
||||
|
|
@ -1117,7 +1117,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertTrue(issubclass(Set, Collection), repr(Set))
|
||||
self.assertTrue(issubclass(MutableSet, Collection), repr(MutableSet))
|
||||
self.assertTrue(issubclass(Sequence, Collection), repr(MutableSet))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check direct subclassing
|
||||
class Col(Collection):
|
||||
def __iter__(self):
|
||||
|
|
@ -1138,7 +1138,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.validate_abstract_methods(Collection, '__len__', '__iter__',
|
||||
'__contains__')
|
||||
# Check sized container non-iterable (which is not Collection) etc.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ColNoIter:
|
||||
def __len__(self): return 0
|
||||
def __contains__(self, item): return False
|
||||
|
|
@ -1155,7 +1155,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertFalse(issubclass(ColNoCont, Collection))
|
||||
self.assertFalse(isinstance(ColNoCont(), Collection))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check None blocking
|
||||
class SizeBlock:
|
||||
def __iter__(self): return iter([])
|
||||
|
|
@ -1169,7 +1169,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertFalse(isinstance(SizeBlock(), Collection))
|
||||
self.assertFalse(issubclass(IterBlock, Collection))
|
||||
self.assertFalse(isinstance(IterBlock(), Collection))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Check None blocking in subclass
|
||||
class ColImpl:
|
||||
def __iter__(self):
|
||||
|
|
@ -1202,7 +1202,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertTrue(issubclass(type(x), Iterator), repr(type(x)))
|
||||
self.validate_abstract_methods(Iterator, '__next__', '__iter__')
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Issue 10565
|
||||
class NextOnly:
|
||||
def __next__(self):
|
||||
|
|
@ -1211,7 +1211,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertNotIsInstance(NextOnly(), Iterator)
|
||||
|
||||
def test_Generator(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NonGen1:
|
||||
def __iter__(self): return self
|
||||
def __next__(self): return None
|
||||
|
|
@ -1236,7 +1236,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
self.assertNotIsInstance(x, Generator)
|
||||
self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Gen:
|
||||
def __iter__(self): return self
|
||||
def __next__(self): return None
|
||||
|
|
@ -1271,14 +1271,14 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
mgen.throw, ValueError, ValueError("huhu"))
|
||||
self.assertRaises(StopIteration, mgen.throw, StopIteration())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class FailOnClose(Generator):
|
||||
def send(self, value): return value
|
||||
def throw(self, *args): raise ValueError
|
||||
|
||||
self.assertRaises(ValueError, FailOnClose().close)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class IgnoreGeneratorExit(Generator):
|
||||
def send(self, value): return value
|
||||
def throw(self, *args): pass
|
||||
|
|
@ -1424,7 +1424,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
|
||||
def test_direct_subclassing(self):
|
||||
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(B):
|
||||
pass
|
||||
self.assertTrue(issubclass(C, B))
|
||||
|
|
@ -1432,7 +1432,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
|
|||
|
||||
def test_registration(self):
|
||||
for B in Hashable, Iterable, Iterator, Reversible, Sized, Container, Callable:
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
__hash__ = None # Make sure it isn't hashable by default
|
||||
self.assertFalse(issubclass(C, B), B.__name__)
|
||||
|
|
@ -1470,7 +1470,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
self.assertIsInstance(sample(), Set)
|
||||
self.assertTrue(issubclass(sample, Set))
|
||||
self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__')
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySet(Set):
|
||||
def __contains__(self, x):
|
||||
return False
|
||||
|
|
@ -1496,7 +1496,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
self.assertTrue(hash(a) == hash(b))
|
||||
|
||||
def test_isdisjoint_Set(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySet(Set):
|
||||
def __init__(self, itr):
|
||||
self.contents = itr
|
||||
|
|
@ -1513,7 +1513,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
self.assertFalse(s1.isdisjoint(s3))
|
||||
|
||||
def test_equality_Set(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySet(Set):
|
||||
def __init__(self, itr):
|
||||
self.contents = itr
|
||||
|
|
@ -1536,7 +1536,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
self.assertNotEqual(s2, s3)
|
||||
|
||||
def test_arithmetic_Set(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySet(Set):
|
||||
def __init__(self, itr):
|
||||
self.contents = itr
|
||||
|
|
@ -1567,7 +1567,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
|
||||
def test_issue_4920(self):
|
||||
# MutableSet.pop() method did not work
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySet(MutableSet):
|
||||
__slots__=['__s']
|
||||
def __init__(self,items=None):
|
||||
|
|
@ -1615,7 +1615,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
def test_issue16373(self):
|
||||
# Recursion error comparing comparable and noncomparable
|
||||
# Set instances
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyComparableSet(Set):
|
||||
def __contains__(self, x):
|
||||
return False
|
||||
|
|
@ -1644,7 +1644,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
|
||||
def test_issue26915(self):
|
||||
# Container membership test should check identity first
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomSequence(Sequence):
|
||||
def __init__(self, seq):
|
||||
self._seq = seq
|
||||
|
|
@ -1676,7 +1676,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
|
||||
def test_Set_from_iterable(self):
|
||||
"""Verify _from_iterable overridden to an instance method works."""
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SetUsingInstanceFromIterable(MutableSet):
|
||||
def __init__(self, values, created_by):
|
||||
if not created_by:
|
||||
|
|
@ -1733,7 +1733,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
|
||||
def test_Set_interoperability_with_real_sets(self):
|
||||
# Issue: 8743
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ListSet(Set):
|
||||
def __init__(self, elements=()):
|
||||
self.data = []
|
||||
|
|
@ -1902,7 +1902,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
self.assertTrue(issubclass(sample, Mapping))
|
||||
self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__',
|
||||
'__getitem__')
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyMapping(Mapping):
|
||||
def __len__(self):
|
||||
return 0
|
||||
|
|
@ -1960,7 +1960,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
'__getitem__')
|
||||
|
||||
def test_Sequence_mixins(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SequenceSubclass(Sequence):
|
||||
def __init__(self, seq=()):
|
||||
self.seq = seq
|
||||
|
|
@ -2041,7 +2041,7 @@ class TestCollectionABCs(ABCTestCase):
|
|||
def test_MutableSequence_mixins(self):
|
||||
# Test the mixins of MutableSequence by creating a minimal concrete
|
||||
# class inherited from it.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MutableSequenceSubclass(MutableSequence):
|
||||
def __init__(self):
|
||||
self.lst = []
|
||||
|
|
@ -2284,7 +2284,7 @@ class TestCounter(__TestCase):
|
|||
check(Counter(words))
|
||||
|
||||
def test_copy_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyCounter(Counter):
|
||||
pass
|
||||
c = MyCounter('slartibartfast')
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
|
|||
+ "test.test_iter",
|
||||
+ "test.typinganndata.ann_module",
|
||||
)
|
||||
|
||||
|
||||
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
||||
+ def find_spec(self, fullname, path, target=None):
|
||||
+ # Check if the import is the problematic one
|
||||
|
|
@ -74,7 +74,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
|
|||
from math import isnan, copysign
|
||||
+import math
|
||||
import operator
|
||||
|
||||
|
||||
+VALID_UNDERSCORE_LITERALS = [
|
||||
+ '0_0_0',
|
||||
+ '4_2',
|
||||
|
|
@ -158,7 +158,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
|
|||
@@ -45,7 +176,40 @@ class WithComplex:
|
||||
def __complex__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
-class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
+class ComplexTest(__TestCase):
|
||||
+
|
||||
|
|
@ -194,13 +194,13 @@ index 6ff1a8ab29d..1572433c5ae 100644
|
|||
+ """
|
||||
+ self.assertFloatIdentical(x.real, y.real)
|
||||
+ self.assertFloatIdentical(x.imag, y.imag)
|
||||
|
||||
|
||||
def assertAlmostEqual(self, a, b):
|
||||
if isinstance(a, complex):
|
||||
@@ -74,6 +238,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
# check that relative difference < eps
|
||||
self.assertTrue(abs((x-y)/y) < eps)
|
||||
|
||||
|
||||
+ def assertFloatsAreIdentical(self, x, y):
|
||||
+ """assert that floats x and y are identical, in the sense that:
|
||||
+ (1) both x and y are nans, or
|
||||
|
|
@ -230,58 +230,58 @@ index 6ff1a8ab29d..1572433c5ae 100644
|
|||
@@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
q = z.__truediv__(y)
|
||||
self.assertClose(q, x)
|
||||
|
||||
|
||||
+ @slowTest
|
||||
def test_truediv(self):
|
||||
simple_real = [float(i) for i in range(-5, 6)]
|
||||
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
|
||||
@@ -338,7 +526,10 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
|
||||
|
||||
def test_boolcontext(self):
|
||||
for i in range(100):
|
||||
- self.assertTrue(complex(random() + 1e-6, random() + 1e-6))
|
||||
+ with torch._dynamo.set_fullgraph(False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ r1 = random()
|
||||
+ r2 = random()
|
||||
+ self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
|
||||
self.assertTrue(not complex(0.0, 0.0))
|
||||
self.assertTrue(1j)
|
||||
|
||||
|
||||
@@ -431,12 +622,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
self.assertRaises(TypeError, complex, WithComplex(1), object())
|
||||
self.assertRaises(TypeError, complex, WithComplex(None), object())
|
||||
|
||||
|
||||
- class EvilExc(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class EvilExc(Exception):
|
||||
+ pass
|
||||
|
||||
|
||||
- class evilcomplex:
|
||||
- def __complex__(self):
|
||||
- raise EvilExc
|
||||
+ class evilcomplex:
|
||||
+ def __complex__(self):
|
||||
+ raise EvilExc
|
||||
|
||||
|
||||
self.assertRaises(EvilExc, complex, evilcomplex())
|
||||
|
||||
|
||||
@@ -460,31 +652,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
|
||||
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
|
||||
|
||||
|
||||
- class MyInt:
|
||||
- def __int__(self):
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyInt:
|
||||
+ def __int__(self):
|
||||
+ return 42
|
||||
|
||||
|
||||
self.assertRaises(TypeError, complex, MyInt())
|
||||
self.assertRaises(TypeError, complex, MyInt(), 1.5)
|
||||
self.assertRaises(TypeError, complex, 1.5, MyInt())
|
||||
|
||||
|
||||
- class complex0(complex):
|
||||
- """Test usage of __complex__() when inheriting from 'complex'"""
|
||||
- def __complex__(self):
|
||||
|
|
@ -299,7 +299,7 @@ index 6ff1a8ab29d..1572433c5ae 100644
|
|||
- complex is returned"""
|
||||
- def __complex__(self):
|
||||
- return None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class complex0(complex):
|
||||
+ """Test usage of __complex__() when inheriting from 'complex'"""
|
||||
+ def __complex__(self):
|
||||
|
|
@ -317,12 +317,12 @@ index 6ff1a8ab29d..1572433c5ae 100644
|
|||
+ complex is returned"""
|
||||
+ def __complex__(self):
|
||||
+ return None
|
||||
|
||||
|
||||
check(complex(complex0(1j)), 0.0, 42.0)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
@@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -526,7 +526,7 @@ class ComplexTest(__TestCase):
|
|||
|
||||
def test_boolcontext(self):
|
||||
for i in range(100):
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
r1 = random()
|
||||
r2 = random()
|
||||
self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
|
||||
|
|
@ -622,7 +622,7 @@ class ComplexTest(__TestCase):
|
|||
self.assertRaises(TypeError, complex, WithComplex(1), object())
|
||||
self.assertRaises(TypeError, complex, WithComplex(None), object())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class EvilExc(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -652,7 +652,7 @@ class ComplexTest(__TestCase):
|
|||
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
|
||||
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyInt:
|
||||
def __int__(self):
|
||||
return 42
|
||||
|
|
@ -661,7 +661,7 @@ class ComplexTest(__TestCase):
|
|||
self.assertRaises(TypeError, complex, MyInt(), 1.5)
|
||||
self.assertRaises(TypeError, complex, 1.5, MyInt())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class complex0(complex):
|
||||
"""Test usage of __complex__() when inheriting from 'complex'"""
|
||||
def __complex__(self):
|
||||
|
|
|
|||
|
|
@ -58,121 +58,121 @@ index cf651959803..256a824932d 100644
|
|||
+# ======= END DYNAMO PATCH =======
|
||||
+
|
||||
"""Unit tests for contextlib.py, and other context managers."""
|
||||
|
||||
|
||||
import io
|
||||
@@ -14,60 +68,67 @@ from test.support.testcase import ExceptionIsLikeMixin
|
||||
import weakref
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestAbstractContextManager(unittest.TestCase):
|
||||
+class TestAbstractContextManager(__TestCase):
|
||||
|
||||
|
||||
def test_enter(self):
|
||||
- class DefaultEnter(AbstractContextManager):
|
||||
- def __exit__(self, *args):
|
||||
- super().__exit__(*args)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class DefaultEnter(AbstractContextManager):
|
||||
+ def __exit__(self, *args):
|
||||
+ super().__exit__(*args)
|
||||
|
||||
|
||||
manager = DefaultEnter()
|
||||
self.assertIs(manager.__enter__(), manager)
|
||||
|
||||
|
||||
def test_slots(self):
|
||||
- class DefaultContextManager(AbstractContextManager):
|
||||
- __slots__ = ()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class DefaultContextManager(AbstractContextManager):
|
||||
+ __slots__ = ()
|
||||
|
||||
|
||||
- def __exit__(self, *args):
|
||||
- super().__exit__(*args)
|
||||
+ def __exit__(self, *args):
|
||||
+ super().__exit__(*args)
|
||||
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
DefaultContextManager().var = 42
|
||||
|
||||
|
||||
def test_exit_is_abstract(self):
|
||||
- class MissingExit(AbstractContextManager):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MissingExit(AbstractContextManager):
|
||||
+ pass
|
||||
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
MissingExit()
|
||||
|
||||
|
||||
def test_structural_subclassing(self):
|
||||
- class ManagerFromScratch:
|
||||
- def __enter__(self):
|
||||
- return self
|
||||
- def __exit__(self, exc_type, exc_value, traceback):
|
||||
- return None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ManagerFromScratch:
|
||||
+ def __enter__(self):
|
||||
+ return self
|
||||
+ def __exit__(self, exc_type, exc_value, traceback):
|
||||
+ return None
|
||||
|
||||
|
||||
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
|
||||
|
||||
|
||||
- class DefaultEnter(AbstractContextManager):
|
||||
- def __exit__(self, *args):
|
||||
- super().__exit__(*args)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class DefaultEnter(AbstractContextManager):
|
||||
+ def __exit__(self, *args):
|
||||
+ super().__exit__(*args)
|
||||
|
||||
|
||||
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
|
||||
|
||||
|
||||
- class NoEnter(ManagerFromScratch):
|
||||
- __enter__ = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class NoEnter(ManagerFromScratch):
|
||||
+ __enter__ = None
|
||||
|
||||
|
||||
self.assertFalse(issubclass(NoEnter, AbstractContextManager))
|
||||
|
||||
|
||||
- class NoExit(ManagerFromScratch):
|
||||
- __exit__ = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class NoExit(ManagerFromScratch):
|
||||
+ __exit__ = None
|
||||
|
||||
|
||||
self.assertFalse(issubclass(NoExit, AbstractContextManager))
|
||||
|
||||
|
||||
|
||||
|
||||
-class ContextManagerTestCase(unittest.TestCase):
|
||||
+class ContextManagerTestCase(__TestCase):
|
||||
|
||||
|
||||
def test_contextmanager_plain(self):
|
||||
state = []
|
||||
@@ -115,8 +176,9 @@ class ContextManagerTestCase(unittest.TestCase):
|
||||
self.assertEqual(frames[0].line, '1/0')
|
||||
|
||||
|
||||
# Repeat with RuntimeError (which goes through a different code path)
|
||||
- class RuntimeErrorSubclass(RuntimeError):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class RuntimeErrorSubclass(RuntimeError):
|
||||
+ pass
|
||||
|
||||
|
||||
try:
|
||||
with f():
|
||||
@@ -128,8 +190,9 @@ class ContextManagerTestCase(unittest.TestCase):
|
||||
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
|
||||
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
|
||||
|
||||
|
||||
- class StopIterationSubclass(StopIteration):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class StopIterationSubclass(StopIteration):
|
||||
+ pass
|
||||
|
||||
|
||||
for stop_exc in (
|
||||
StopIteration('spam'),
|
||||
@@ -169,9 +232,9 @@ class ContextManagerTestCase(unittest.TestCase):
|
||||
|
|
@ -185,7 +185,7 @@ index cf651959803..256a824932d 100644
|
|||
+ # if support.check_impl_detail(cpython=True):
|
||||
+ # # The "gen" attribute is an implementation detail.
|
||||
+ # self.assertFalse(ctx.gen.gi_suspended)
|
||||
|
||||
|
||||
def test_contextmanager_trap_no_yield(self):
|
||||
@contextmanager
|
||||
@@ -191,9 +254,9 @@ class ContextManagerTestCase(unittest.TestCase):
|
||||
|
|
@ -198,50 +198,50 @@ index cf651959803..256a824932d 100644
|
|||
+ # if support.check_impl_detail(cpython=True):
|
||||
+ # # The "gen" attribute is an implementation detail.
|
||||
+ # self.assertFalse(ctx.gen.gi_suspended)
|
||||
|
||||
|
||||
def test_contextmanager_non_normalised(self):
|
||||
@contextmanager
|
||||
@@ -230,8 +293,9 @@ class ContextManagerTestCase(unittest.TestCase):
|
||||
def woohoo():
|
||||
yield
|
||||
|
||||
|
||||
- class StopIterationSubclass(StopIteration):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class StopIterationSubclass(StopIteration):
|
||||
+ pass
|
||||
|
||||
|
||||
for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
|
||||
with self.subTest(type=type(stop_exc)):
|
||||
@@ -344,8 +408,9 @@ def woohoo():
|
||||
self.assertEqual(target, (11, 22, 33, 44))
|
||||
|
||||
|
||||
def test_nokeepref(self):
|
||||
- class A:
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def woohoo(a, b):
|
||||
@@ -396,7 +461,7 @@ def woohoo():
|
||||
self.assertEqual(depth, 0)
|
||||
|
||||
|
||||
|
||||
|
||||
-class ClosingTestCase(unittest.TestCase):
|
||||
+class ClosingTestCase(__TestCase):
|
||||
|
||||
|
||||
@support.requires_docstrings
|
||||
def test_instance_docs(self):
|
||||
@@ -407,9 +472,10 @@ class ClosingTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
def test_closing(self):
|
||||
state = []
|
||||
- class C:
|
||||
- def close(self):
|
||||
- state.append(1)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ def close(self):
|
||||
+ state.append(1)
|
||||
|
|
@ -249,13 +249,13 @@ index cf651959803..256a824932d 100644
|
|||
self.assertEqual(state, [])
|
||||
with closing(x) as y:
|
||||
@@ -418,9 +484,10 @@ class ClosingTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
def test_closing_error(self):
|
||||
state = []
|
||||
- class C:
|
||||
- def close(self):
|
||||
- state.append(1)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ def close(self):
|
||||
+ state.append(1)
|
||||
|
|
@ -264,52 +264,52 @@ index cf651959803..256a824932d 100644
|
|||
with self.assertRaises(ZeroDivisionError):
|
||||
@@ -430,16 +497,17 @@ class ClosingTestCase(unittest.TestCase):
|
||||
self.assertEqual(state, [1])
|
||||
|
||||
|
||||
|
||||
|
||||
-class NullcontextTestCase(unittest.TestCase):
|
||||
+class NullcontextTestCase(__TestCase):
|
||||
def test_nullcontext(self):
|
||||
- class C:
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ pass
|
||||
c = C()
|
||||
with nullcontext(c) as c_in:
|
||||
self.assertIs(c_in, c)
|
||||
|
||||
|
||||
|
||||
|
||||
-class FileContextTestCase(unittest.TestCase):
|
||||
+class FileContextTestCase(__TestCase):
|
||||
|
||||
|
||||
def testWithOpen(self):
|
||||
tfn = tempfile.mktemp()
|
||||
@@ -457,7 +525,7 @@ class FileContextTestCase(unittest.TestCase):
|
||||
finally:
|
||||
os_helper.unlink(tfn)
|
||||
|
||||
|
||||
-class LockContextTestCase(unittest.TestCase):
|
||||
+class LockContextTestCase(__TestCase):
|
||||
|
||||
|
||||
def boilerPlate(self, lock, locked):
|
||||
self.assertFalse(locked())
|
||||
@@ -520,7 +588,7 @@ class mycontext(ContextDecorator):
|
||||
return self.catch
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestContextDecorator(unittest.TestCase):
|
||||
+class TestContextDecorator(__TestCase):
|
||||
|
||||
|
||||
@support.requires_docstrings
|
||||
def test_instance_docs(self):
|
||||
@@ -584,13 +652,14 @@ class TestContextDecorator(unittest.TestCase):
|
||||
def test_decorating_method(self):
|
||||
context = mycontext()
|
||||
|
||||
|
||||
- class Test(object):
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Test(object):
|
||||
|
||||
|
||||
- @context
|
||||
- def method(self, a, b, c=None):
|
||||
- self.a = a
|
||||
|
|
@ -320,84 +320,84 @@ index cf651959803..256a824932d 100644
|
|||
+ self.a = a
|
||||
+ self.b = b
|
||||
+ self.c = c
|
||||
|
||||
|
||||
# these tests are for argument passing when used as a decorator
|
||||
test = Test()
|
||||
@@ -612,11 +681,12 @@ class TestContextDecorator(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
def test_typo_enter(self):
|
||||
- class mycontext(ContextDecorator):
|
||||
- def __unter__(self):
|
||||
- pass
|
||||
- def __exit__(self, *exc):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class mycontext(ContextDecorator):
|
||||
+ def __unter__(self):
|
||||
+ pass
|
||||
+ def __exit__(self, *exc):
|
||||
+ pass
|
||||
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'the context manager'):
|
||||
with mycontext():
|
||||
@@ -624,11 +694,12 @@ class TestContextDecorator(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
def test_typo_exit(self):
|
||||
- class mycontext(ContextDecorator):
|
||||
- def __enter__(self):
|
||||
- pass
|
||||
- def __uxit__(self, *exc):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class mycontext(ContextDecorator):
|
||||
+ def __enter__(self):
|
||||
+ pass
|
||||
+ def __uxit__(self, *exc):
|
||||
+ pass
|
||||
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
|
||||
with mycontext():
|
||||
@@ -636,19 +707,20 @@ class TestContextDecorator(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
def test_contextdecorator_as_mixin(self):
|
||||
- class somecontext(object):
|
||||
- started = False
|
||||
- exc = None
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class somecontext(object):
|
||||
+ started = False
|
||||
+ exc = None
|
||||
|
||||
|
||||
- def __enter__(self):
|
||||
- self.started = True
|
||||
- return self
|
||||
+ def __enter__(self):
|
||||
+ self.started = True
|
||||
+ return self
|
||||
|
||||
|
||||
- def __exit__(self, *exc):
|
||||
- self.exc = exc
|
||||
+ def __exit__(self, *exc):
|
||||
+ self.exc = exc
|
||||
|
||||
|
||||
- class mycontext(somecontext, ContextDecorator):
|
||||
- pass
|
||||
+ class mycontext(somecontext, ContextDecorator):
|
||||
+ pass
|
||||
|
||||
|
||||
context = mycontext()
|
||||
@context
|
||||
@@ -680,7 +752,7 @@ class TestContextDecorator(unittest.TestCase):
|
||||
self.assertEqual(state, [1, 'something else', 999])
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestBaseExitStack:
|
||||
+class _TestBaseExitStack:
|
||||
exit_stack = None
|
||||
|
||||
|
||||
@support.requires_docstrings
|
||||
@@ -745,13 +817,14 @@ class TestBaseExitStack:
|
||||
self.assertIsNone(exc_type)
|
||||
|
|
@ -410,7 +410,7 @@ index cf651959803..256a824932d 100644
|
|||
- self.fail("Should not be called!")
|
||||
- def __exit__(self, *exc_details):
|
||||
- self.check_exc(*exc_details)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ExitCM(object):
|
||||
+ def __init__(self, check_exc):
|
||||
+ self.check_exc = check_exc
|
||||
|
|
@ -423,25 +423,25 @@ index cf651959803..256a824932d 100644
|
|||
self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
|
||||
@@ -770,11 +843,12 @@ class TestBaseExitStack:
|
||||
1/0
|
||||
|
||||
|
||||
def test_enter_context(self):
|
||||
- class TestCM(object):
|
||||
- def __enter__(self):
|
||||
- result.append(1)
|
||||
- def __exit__(self, *exc_details):
|
||||
- result.append(3)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class TestCM(object):
|
||||
+ def __enter__(self):
|
||||
+ result.append(1)
|
||||
+ def __exit__(self, *exc_details):
|
||||
+ result.append(3)
|
||||
|
||||
|
||||
result = []
|
||||
cm = TestCM()
|
||||
@@ -789,14 +863,15 @@ class TestBaseExitStack:
|
||||
self.assertEqual(result, [1, 2, 3, 4])
|
||||
|
||||
|
||||
def test_enter_context_errors(self):
|
||||
- class LacksEnterAndExit:
|
||||
- pass
|
||||
|
|
@ -450,7 +450,7 @@ index cf651959803..256a824932d 100644
|
|||
- pass
|
||||
- class LacksExit:
|
||||
- def __enter__(self):
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class LacksEnterAndExit:
|
||||
pass
|
||||
+ class LacksEnter:
|
||||
|
|
@ -459,7 +459,7 @@ index cf651959803..256a824932d 100644
|
|||
+ class LacksExit:
|
||||
+ def __enter__(self):
|
||||
+ pass
|
||||
|
||||
|
||||
with self.exit_stack() as stack:
|
||||
with self.assertRaisesRegex(TypeError, 'the context manager'):
|
||||
@@ -877,32 +952,33 @@ class TestBaseExitStack:
|
||||
|
|
@ -492,7 +492,7 @@ index cf651959803..256a824932d 100644
|
|||
- def __exit__(self, *exc_details):
|
||||
- type(self).saved_details = exc_details
|
||||
- return True
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class RaiseExc:
|
||||
+ def __init__(self, exc):
|
||||
+ self.exc = exc
|
||||
|
|
@ -519,47 +519,47 @@ index cf651959803..256a824932d 100644
|
|||
+ def __exit__(self, *exc_details):
|
||||
+ type(self).saved_details = exc_details
|
||||
+ return True
|
||||
|
||||
|
||||
try:
|
||||
with RaiseExc(IndexError):
|
||||
@@ -957,8 +1033,9 @@ class TestBaseExitStack:
|
||||
# Ensure ExitStack chaining matches actual nested `with` statements
|
||||
# regarding explicit __context__ = None.
|
||||
|
||||
|
||||
- class MyException(Exception):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyException(Exception):
|
||||
+ pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def my_cm():
|
||||
@@ -1096,7 +1173,8 @@ class TestBaseExitStack:
|
||||
stack.callback(int)
|
||||
|
||||
|
||||
def test_instance_bypass(self):
|
||||
- class Example(object): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Example(object): pass
|
||||
cm = Example()
|
||||
cm.__enter__ = object()
|
||||
cm.__exit__ = object()
|
||||
@@ -1108,8 +1186,9 @@ class TestBaseExitStack:
|
||||
|
||||
|
||||
def test_dont_reraise_RuntimeError(self):
|
||||
# https://bugs.python.org/issue27122
|
||||
- class UniqueException(Exception): pass
|
||||
- class UniqueRuntimeError(RuntimeError): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class UniqueException(Exception): pass
|
||||
+ class UniqueRuntimeError(RuntimeError): pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def second():
|
||||
@@ -1141,7 +1220,7 @@ class TestBaseExitStack:
|
||||
self.assertIs(exc.__cause__, exc.__context__)
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestExitStack(TestBaseExitStack, unittest.TestCase):
|
||||
+class TestExitStack(_TestBaseExitStack, __TestCase):
|
||||
exit_stack = ExitStack
|
||||
|
|
@ -567,40 +567,40 @@ index cf651959803..256a824932d 100644
|
|||
('__exit__', 'raise exc'),
|
||||
@@ -1149,7 +1228,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase):
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestRedirectStream:
|
||||
+class _TestRedirectStream:
|
||||
|
||||
|
||||
redirect_stream = None
|
||||
orig_stream = None
|
||||
@@ -1206,19 +1285,19 @@ class TestRedirectStream:
|
||||
self.assertEqual(s, "Hello World!\n")
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
|
||||
+class TestRedirectStdout(_TestRedirectStream, __TestCase):
|
||||
|
||||
|
||||
redirect_stream = redirect_stdout
|
||||
orig_stream = "stdout"
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
|
||||
+class TestRedirectStderr(_TestRedirectStream, __TestCase):
|
||||
|
||||
|
||||
redirect_stream = redirect_stderr
|
||||
orig_stream = "stderr"
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
|
||||
+class TestSuppress(ExceptionIsLikeMixin, __TestCase):
|
||||
|
||||
|
||||
@support.requires_docstrings
|
||||
def test_instance_docs(self):
|
||||
@@ -1315,7 +1394,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestChdir(unittest.TestCase):
|
||||
+class TestChdir(__TestCase):
|
||||
def make_relative_path(self, *parts):
|
||||
|
|
@ -609,14 +609,14 @@ index cf651959803..256a824932d 100644
|
|||
@@ -1331,6 +1410,7 @@ class TestChdir(unittest.TestCase):
|
||||
self.assertEqual(os.getcwd(), target)
|
||||
self.assertEqual(os.getcwd(), old_cwd)
|
||||
|
||||
|
||||
+ @unittest.skip("Missing archivetestdata")
|
||||
def test_reentrant(self):
|
||||
old_cwd = os.getcwd()
|
||||
target1 = self.make_relative_path('data')
|
||||
@@ -1363,4 +1443,4 @@ class TestChdir(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ import weakref
|
|||
class TestAbstractContextManager(__TestCase):
|
||||
|
||||
def test_enter(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class DefaultEnter(AbstractContextManager):
|
||||
def __exit__(self, *args):
|
||||
super().__exit__(*args)
|
||||
|
|
@ -80,7 +80,7 @@ class TestAbstractContextManager(__TestCase):
|
|||
self.assertIs(manager.__enter__(), manager)
|
||||
|
||||
def test_slots(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class DefaultContextManager(AbstractContextManager):
|
||||
__slots__ = ()
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ class TestAbstractContextManager(__TestCase):
|
|||
DefaultContextManager().var = 42
|
||||
|
||||
def test_exit_is_abstract(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MissingExit(AbstractContextManager):
|
||||
pass
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ class TestAbstractContextManager(__TestCase):
|
|||
MissingExit()
|
||||
|
||||
def test_structural_subclassing(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ManagerFromScratch:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
|
@ -108,20 +108,20 @@ class TestAbstractContextManager(__TestCase):
|
|||
|
||||
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class DefaultEnter(AbstractContextManager):
|
||||
def __exit__(self, *args):
|
||||
super().__exit__(*args)
|
||||
|
||||
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NoEnter(ManagerFromScratch):
|
||||
__enter__ = None
|
||||
|
||||
self.assertFalse(issubclass(NoEnter, AbstractContextManager))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NoExit(ManagerFromScratch):
|
||||
__exit__ = None
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ class ContextManagerTestCase(__TestCase):
|
|||
self.assertEqual(frames[0].line, '1/0')
|
||||
|
||||
# Repeat with RuntimeError (which goes through a different code path)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class RuntimeErrorSubclass(RuntimeError):
|
||||
pass
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ class ContextManagerTestCase(__TestCase):
|
|||
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
|
||||
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class StopIterationSubclass(StopIteration):
|
||||
pass
|
||||
|
||||
|
|
@ -293,7 +293,7 @@ class ContextManagerTestCase(__TestCase):
|
|||
def woohoo():
|
||||
yield
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class StopIterationSubclass(StopIteration):
|
||||
pass
|
||||
|
||||
|
|
@ -408,7 +408,7 @@ def woohoo():
|
|||
self.assertEqual(target, (11, 22, 33, 44))
|
||||
|
||||
def test_nokeepref(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
pass
|
||||
|
||||
|
|
@ -472,7 +472,7 @@ class ClosingTestCase(__TestCase):
|
|||
|
||||
def test_closing(self):
|
||||
state = []
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def close(self):
|
||||
state.append(1)
|
||||
|
|
@ -484,7 +484,7 @@ class ClosingTestCase(__TestCase):
|
|||
|
||||
def test_closing_error(self):
|
||||
state = []
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def close(self):
|
||||
state.append(1)
|
||||
|
|
@ -499,7 +499,7 @@ class ClosingTestCase(__TestCase):
|
|||
|
||||
class NullcontextTestCase(__TestCase):
|
||||
def test_nullcontext(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
pass
|
||||
c = C()
|
||||
|
|
@ -652,7 +652,7 @@ class TestContextDecorator(__TestCase):
|
|||
def test_decorating_method(self):
|
||||
context = mycontext()
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Test(object):
|
||||
|
||||
@context
|
||||
|
|
@ -681,7 +681,7 @@ class TestContextDecorator(__TestCase):
|
|||
|
||||
|
||||
def test_typo_enter(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class mycontext(ContextDecorator):
|
||||
def __unter__(self):
|
||||
pass
|
||||
|
|
@ -694,7 +694,7 @@ class TestContextDecorator(__TestCase):
|
|||
|
||||
|
||||
def test_typo_exit(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class mycontext(ContextDecorator):
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
|
@ -707,7 +707,7 @@ class TestContextDecorator(__TestCase):
|
|||
|
||||
|
||||
def test_contextdecorator_as_mixin(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class somecontext(object):
|
||||
started = False
|
||||
exc = None
|
||||
|
|
@ -817,7 +817,7 @@ class _TestBaseExitStack:
|
|||
self.assertIsNone(exc_type)
|
||||
self.assertIsNone(exc)
|
||||
self.assertIsNone(exc_tb)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ExitCM(object):
|
||||
def __init__(self, check_exc):
|
||||
self.check_exc = check_exc
|
||||
|
|
@ -843,7 +843,7 @@ class _TestBaseExitStack:
|
|||
1/0
|
||||
|
||||
def test_enter_context(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class TestCM(object):
|
||||
def __enter__(self):
|
||||
result.append(1)
|
||||
|
|
@ -863,7 +863,7 @@ class _TestBaseExitStack:
|
|||
self.assertEqual(result, [1, 2, 3, 4])
|
||||
|
||||
def test_enter_context_errors(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class LacksEnterAndExit:
|
||||
pass
|
||||
class LacksEnter:
|
||||
|
|
@ -952,7 +952,7 @@ class _TestBaseExitStack:
|
|||
def test_exit_exception_chaining_reference(self):
|
||||
# Sanity check to make sure that ExitStack chaining matches
|
||||
# actual nested with statements
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class RaiseExc:
|
||||
def __init__(self, exc):
|
||||
self.exc = exc
|
||||
|
|
@ -1033,7 +1033,7 @@ class _TestBaseExitStack:
|
|||
# Ensure ExitStack chaining matches actual nested `with` statements
|
||||
# regarding explicit __context__ = None.
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyException(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -1173,7 +1173,7 @@ class _TestBaseExitStack:
|
|||
stack.callback(int)
|
||||
|
||||
def test_instance_bypass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Example(object): pass
|
||||
cm = Example()
|
||||
cm.__enter__ = object()
|
||||
|
|
@ -1186,7 +1186,7 @@ class _TestBaseExitStack:
|
|||
|
||||
def test_dont_reraise_RuntimeError(self):
|
||||
# https://bugs.python.org/issue27122
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class UniqueException(Exception): pass
|
||||
class UniqueRuntimeError(RuntimeError): pass
|
||||
|
||||
|
|
|
|||
|
|
@ -61,19 +61,19 @@ index bdbe9b81e8f..d55f1dc54c6 100644
|
|||
+
|
||||
+
|
||||
"""Unit tests for collections.defaultdict."""
|
||||
|
||||
|
||||
import copy
|
||||
@@ -9,7 +66,7 @@ from collections import defaultdict
|
||||
def foobar():
|
||||
return list
|
||||
|
||||
|
||||
-class TestDefaultDict(unittest.TestCase):
|
||||
+class TestDefaultDict(__TestCase):
|
||||
|
||||
|
||||
def test_basic(self):
|
||||
d1 = defaultdict()
|
||||
@@ -127,11 +184,12 @@ class TestDefaultDict(unittest.TestCase):
|
||||
|
||||
|
||||
def test_recursive_repr(self):
|
||||
# Issue2045: stack overflow when default_factory is a bound method
|
||||
- class sub(defaultdict):
|
||||
|
|
@ -81,7 +81,7 @@ index bdbe9b81e8f..d55f1dc54c6 100644
|
|||
- self.default_factory = self._factory
|
||||
- def _factory(self):
|
||||
- return []
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class sub(defaultdict):
|
||||
+ def __init__(self):
|
||||
+ self.default_factory = self._factory
|
||||
|
|
@ -92,7 +92,7 @@ index bdbe9b81e8f..d55f1dc54c6 100644
|
|||
r"sub\(<bound method .*sub\._factory "
|
||||
@@ -187,4 +245,4 @@ class TestDefaultDict(unittest.TestCase):
|
||||
i |= None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ class TestDefaultDict(__TestCase):
|
|||
|
||||
def test_recursive_repr(self):
|
||||
# Issue2045: stack overflow when default_factory is a bound method
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class sub(defaultdict):
|
||||
def __init__(self):
|
||||
self.default_factory = self._factory
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -71,7 +71,7 @@ from test.support import import_helper, get_c_recursion_limit
|
|||
class DictTest(__TestCase):
|
||||
|
||||
def test_invalid_keyword_arguments(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Custom(dict):
|
||||
pass
|
||||
for invalid in {1 : 2}, Custom({1 : 2}):
|
||||
|
|
@ -166,7 +166,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_views_mapping(self):
|
||||
mappingproxy = type(type.__dict__)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Dict(dict):
|
||||
pass
|
||||
for cls in [dict, Dict]:
|
||||
|
|
@ -216,7 +216,7 @@ class DictTest(__TestCase):
|
|||
|
||||
self.assertRaises(TypeError, d.__getitem__)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadEq(object):
|
||||
def __eq__(self, other):
|
||||
raise Exc()
|
||||
|
|
@ -227,7 +227,7 @@ class DictTest(__TestCase):
|
|||
d[BadEq()] = 42
|
||||
self.assertRaises(KeyError, d.__getitem__, 23)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadHash(object):
|
||||
|
|
@ -262,7 +262,7 @@ class DictTest(__TestCase):
|
|||
|
||||
self.assertRaises((TypeError, AttributeError), d.update, None)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SimpleUserDict:
|
||||
def __init__(self):
|
||||
self.d = {1:1, 2:2, 3:3}
|
||||
|
|
@ -274,18 +274,18 @@ class DictTest(__TestCase):
|
|||
d.update(SimpleUserDict())
|
||||
self.assertEqual(d, {1:1, 2:2, 3:3})
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
d.clear()
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class FailingUserDict:
|
||||
def keys(self):
|
||||
raise Exc
|
||||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class FailingUserDict:
|
||||
def keys(self):
|
||||
class BogonIter:
|
||||
|
|
@ -303,7 +303,7 @@ class DictTest(__TestCase):
|
|||
return key
|
||||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class FailingUserDict:
|
||||
def keys(self):
|
||||
class BogonIter:
|
||||
|
|
@ -323,7 +323,7 @@ class DictTest(__TestCase):
|
|||
self.assertRaises(Exc, d.update, FailingUserDict())
|
||||
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class badseq(object):
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
@ -346,13 +346,13 @@ class DictTest(__TestCase):
|
|||
yield 1
|
||||
self.assertEqual(d.fromkeys(g()), {1:None})
|
||||
self.assertRaises(TypeError, {}.fromkeys, 3)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class dictlike(dict): pass
|
||||
self.assertEqual(dictlike.fromkeys('a'), {'a':None})
|
||||
self.assertEqual(dictlike().fromkeys('a'), {'a':None})
|
||||
self.assertIsInstance(dictlike.fromkeys('a'), dictlike)
|
||||
self.assertIsInstance(dictlike().fromkeys('a'), dictlike)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class mydict(dict):
|
||||
def __new__(cls):
|
||||
return collections.UserDict()
|
||||
|
|
@ -361,7 +361,7 @@ class DictTest(__TestCase):
|
|||
self.assertIsInstance(ud, collections.UserDict)
|
||||
self.assertRaises(TypeError, dict.fromkeys)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class baddict1(dict):
|
||||
|
|
@ -370,7 +370,7 @@ class DictTest(__TestCase):
|
|||
|
||||
self.assertRaises(Exc, baddict1.fromkeys, [1])
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadSeq(object):
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
@ -379,7 +379,7 @@ class DictTest(__TestCase):
|
|||
|
||||
self.assertRaises(Exc, dict.fromkeys, BadSeq())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class baddict2(dict):
|
||||
def __setitem__(self, key, value):
|
||||
raise Exc()
|
||||
|
|
@ -398,7 +398,7 @@ class DictTest(__TestCase):
|
|||
self.assertEqual(dict.fromkeys(d, 0), res)
|
||||
|
||||
# test fast path when object's constructor returns large non-empty dict
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class baddict3(dict):
|
||||
def __new__(cls):
|
||||
return d
|
||||
|
|
@ -408,7 +408,7 @@ class DictTest(__TestCase):
|
|||
self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res)
|
||||
|
||||
# test slow path when object is a proper subclass of dict
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class baddict4(dict):
|
||||
def __init__(self):
|
||||
dict.__init__(self, d)
|
||||
|
|
@ -447,7 +447,7 @@ class DictTest(__TestCase):
|
|||
self.assertEqual(len(d2), len(d) + 1)
|
||||
|
||||
def test_copy_maintains_tracking(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
pass
|
||||
|
||||
|
|
@ -495,7 +495,7 @@ class DictTest(__TestCase):
|
|||
self.assertRaises(TypeError, d.setdefault)
|
||||
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadHash(object):
|
||||
|
|
@ -513,7 +513,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_setdefault_atomic(self):
|
||||
# Issue #13521: setdefault() calls __hash__ and __eq__ only once.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Hashed(object):
|
||||
def __init__(self):
|
||||
self.hash_count = 0
|
||||
|
|
@ -533,7 +533,7 @@ class DictTest(__TestCase):
|
|||
self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1)
|
||||
|
||||
def test_setitem_atomic_at_resize(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Hashed(object):
|
||||
def __init__(self):
|
||||
self.hash_count = 0
|
||||
|
|
@ -599,7 +599,7 @@ class DictTest(__TestCase):
|
|||
|
||||
self.assertRaises(TypeError, d.pop)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadHash(object):
|
||||
|
|
@ -652,7 +652,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_mutating_lookup(self):
|
||||
# changing dict during a lookup (issue #14417)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NastyKey:
|
||||
mutate_dict = None
|
||||
|
||||
|
|
@ -686,7 +686,7 @@ class DictTest(__TestCase):
|
|||
d[1] = d
|
||||
self.assertEqual(repr(d), '{1: {...}}')
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadRepr(object):
|
||||
|
|
@ -706,7 +706,7 @@ class DictTest(__TestCase):
|
|||
self.assertEqual({}, {})
|
||||
self.assertEqual({1: 2}, {1: 2})
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
|
||||
class BadCmp(object):
|
||||
|
|
@ -770,7 +770,7 @@ class DictTest(__TestCase):
|
|||
self.assertFalse(larger == larger3)
|
||||
|
||||
def test_errors_in_view_containment_check(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def __eq__(self, other):
|
||||
raise RuntimeError
|
||||
|
|
@ -853,7 +853,7 @@ class DictTest(__TestCase):
|
|||
# (E) subclass defines __missing__ method raising RuntimeError
|
||||
# (F) subclass sets __missing__ instance variable (no effect)
|
||||
# (G) subclass doesn't define __missing__ at all
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class D(dict):
|
||||
def __missing__(self, key):
|
||||
return 42
|
||||
|
|
@ -864,7 +864,7 @@ class DictTest(__TestCase):
|
|||
self.assertNotIn(2, d.keys())
|
||||
self.assertEqual(d[2], 42)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class E(dict):
|
||||
def __missing__(self, key):
|
||||
raise RuntimeError(key)
|
||||
|
|
@ -873,7 +873,7 @@ class DictTest(__TestCase):
|
|||
e[42]
|
||||
self.assertEqual(c.exception.args, (42,))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class F(dict):
|
||||
def __init__(self):
|
||||
# An instance variable __missing__ should have no effect
|
||||
|
|
@ -883,7 +883,7 @@ class DictTest(__TestCase):
|
|||
f[42]
|
||||
self.assertEqual(c.exception.args, (42,))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class G(dict):
|
||||
pass
|
||||
g = G()
|
||||
|
|
@ -900,7 +900,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_bad_key(self):
|
||||
# Dictionary lookups should fail if __eq__() raises an exception.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomException(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -947,7 +947,7 @@ class DictTest(__TestCase):
|
|||
# Another dict resizing bug (SF bug #1456209).
|
||||
# This caused Segmentation faults or Illegal instructions.
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X(object):
|
||||
def __hash__(self):
|
||||
return 5
|
||||
|
|
@ -977,7 +977,7 @@ class DictTest(__TestCase):
|
|||
def test_container_iterator(self):
|
||||
# Bug #3680: tp_traverse was not implemented for dictiter and
|
||||
# dictview objects.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
pass
|
||||
views = (dict.items, dict.values, dict.keys)
|
||||
|
|
@ -1033,7 +1033,7 @@ class DictTest(__TestCase):
|
|||
def test_track_dynamic(self):
|
||||
# Test GC-optimization of dynamically-created dicts
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyObject(object):
|
||||
pass
|
||||
x, y, z, w, o = 1.5, "a", (1, object()), [], MyObject()
|
||||
|
|
@ -1103,7 +1103,7 @@ class DictTest(__TestCase):
|
|||
self._tracked(MyDict())
|
||||
|
||||
def make_shared_key_dict(self, n):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
pass
|
||||
|
||||
|
|
@ -1194,7 +1194,7 @@ class DictTest(__TestCase):
|
|||
@support.cpython_only
|
||||
def test_splittable_update(self):
|
||||
"""dict.update(other) must preserve order in other."""
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def __init__(self, order):
|
||||
if order:
|
||||
|
|
@ -1212,7 +1212,7 @@ class DictTest(__TestCase):
|
|||
@support.cpython_only
|
||||
def test_splittable_to_generic_combinedtable(self):
|
||||
"""split table must be correctly resized and converted to generic combined table"""
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
pass
|
||||
|
||||
|
|
@ -1336,19 +1336,19 @@ class DictTest(__TestCase):
|
|||
self.assertEqual(sorted(values), sorted(data.values()))
|
||||
|
||||
def test_instance_dict_getattr_str_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Foo:
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
f = Foo('123')
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class _str(str):
|
||||
pass
|
||||
self.assertEqual(f.msg, getattr(f, _str('msg')))
|
||||
self.assertEqual(f.msg, f.__dict__[_str('msg')])
|
||||
|
||||
def test_object_set_item_single_instance_non_str_key(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Foo: pass
|
||||
f = Foo()
|
||||
f.__dict__[1] = 1
|
||||
|
|
@ -1359,7 +1359,7 @@ class DictTest(__TestCase):
|
|||
# This object will trigger mutation of the dict when replaced
|
||||
# by another value. Note this relies on refcounting: the test
|
||||
# won't achieve its purpose on fully-GCed Python implementations.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Mutating:
|
||||
def __del__(self):
|
||||
mutate(d)
|
||||
|
|
@ -1385,7 +1385,7 @@ class DictTest(__TestCase):
|
|||
self.check_reentrant_insertion(mutate)
|
||||
|
||||
def test_merge_and_mutate(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __hash__(self):
|
||||
return 0
|
||||
|
|
@ -1408,7 +1408,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_equal_operator_modifying_operand(self):
|
||||
# test fix for seg fault reported in bpo-27945 part 3.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X():
|
||||
def __del__(self):
|
||||
dict_b.clear()
|
||||
|
|
@ -1425,7 +1425,7 @@ class DictTest(__TestCase):
|
|||
self.assertTrue(dict_a == dict_b)
|
||||
|
||||
# test fix for seg fault reported in bpo-38588 part 1.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Y:
|
||||
def __eq__(self, other):
|
||||
dict_d.clear()
|
||||
|
|
@ -1437,7 +1437,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_fromkeys_operator_modifying_dict_operand(self):
|
||||
# test fix for seg fault reported in issue 27945 part 4a.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X(int):
|
||||
def __hash__(self):
|
||||
return 13
|
||||
|
|
@ -1456,7 +1456,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_fromkeys_operator_modifying_set_operand(self):
|
||||
# test fix for seg fault reported in issue 27945 part 4b.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X(int):
|
||||
def __hash__(self):
|
||||
return 13
|
||||
|
|
@ -1474,7 +1474,7 @@ class DictTest(__TestCase):
|
|||
pass
|
||||
|
||||
def test_dictitems_contains_use_after_free(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __eq__(self, other):
|
||||
d.clear()
|
||||
|
|
@ -1485,7 +1485,7 @@ class DictTest(__TestCase):
|
|||
|
||||
def test_dict_contain_use_after_free(self):
|
||||
# bpo-40489
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class S(str):
|
||||
def __eq__(self, other):
|
||||
d.clear()
|
||||
|
|
@ -1498,7 +1498,7 @@ class DictTest(__TestCase):
|
|||
self.assertFalse('test' in d)
|
||||
|
||||
def test_init_use_after_free(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __hash__(self):
|
||||
pair[:] = []
|
||||
|
|
@ -1508,7 +1508,7 @@ class DictTest(__TestCase):
|
|||
dict([pair])
|
||||
|
||||
def test_oob_indexing_dictiter_iternextitem(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X(int):
|
||||
def __del__(self):
|
||||
d.clear()
|
||||
|
|
@ -1545,7 +1545,7 @@ class DictTest(__TestCase):
|
|||
self.assertEqual(list(reversed(dict().keys())), [])
|
||||
|
||||
def test_reverse_iterator_for_shared_shared_dicts(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
def __init__(self, x, y):
|
||||
if x: self.x = x
|
||||
|
|
@ -1565,7 +1565,7 @@ class DictTest(__TestCase):
|
|||
self.assertEqual(list(copy.items()), expected)
|
||||
|
||||
# dict subclass doesn't override __iter__
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomDict(dict):
|
||||
pass
|
||||
|
||||
|
|
@ -1574,7 +1574,7 @@ class DictTest(__TestCase):
|
|||
d = CustomDict(pairs)
|
||||
self.assertEqual(pairs, list(dict(d).items()))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomReversedDict(dict):
|
||||
def keys(self):
|
||||
return reversed(list(dict.keys(self)))
|
||||
|
|
@ -1607,7 +1607,7 @@ class DictTest(__TestCase):
|
|||
self.assertTrue(gc.is_tracked(next(it)))
|
||||
|
||||
def test_store_evilattr(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class EvilAttr:
|
||||
def __init__(self, d):
|
||||
self.d = d
|
||||
|
|
@ -1630,13 +1630,13 @@ class DictTest(__TestCase):
|
|||
# `str` keys. Make sure the unoptimized path is used when a non-`str`
|
||||
# key appears.
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class StrSub(str):
|
||||
pass
|
||||
|
||||
eq_count = 0
|
||||
# This class compares equal to the string 'key3'
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key3:
|
||||
def __hash__(self):
|
||||
return hash('key3')
|
||||
|
|
@ -1746,7 +1746,7 @@ class CAPITest(__TestCase):
|
|||
# key does not exist
|
||||
self.assertRaises(KeyError, dict_getitem_knownhash, {}, 1, hash(1))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Exc(Exception): pass
|
||||
class BadEq:
|
||||
def __eq__(self, other):
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ index 97f951f1299..da82bd190c3 100644
|
|||
import os
|
||||
@@ -8,11 +62,84 @@ import time
|
||||
import unittest
|
||||
|
||||
|
||||
from test import support
|
||||
-from test.support.testcase import FloatsAreIdenticalMixin
|
||||
-from test.support.numbers import (
|
||||
|
|
@ -149,14 +149,14 @@ index 97f951f1299..da82bd190c3 100644
|
|||
+
|
||||
from math import isinf, isnan, copysign, ldexp
|
||||
import math
|
||||
|
||||
|
||||
@@ -35,7 +162,7 @@ class FloatSubclass(float):
|
||||
class OtherFloatSubclass(float):
|
||||
pass
|
||||
|
||||
|
||||
-class GeneralFloatCases(unittest.TestCase):
|
||||
+class GeneralFloatCases(__TestCase):
|
||||
|
||||
|
||||
def test_float(self):
|
||||
self.assertEqual(float(3.14), 3.14)
|
||||
@@ -95,9 +222,10 @@ class GeneralFloatCases(unittest.TestCase):
|
||||
|
|
@ -166,51 +166,51 @@ index 97f951f1299..da82bd190c3 100644
|
|||
- class CustomStr(str): pass
|
||||
- class CustomBytes(bytes): pass
|
||||
- class CustomByteArray(bytearray): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class CustomStr(str): pass
|
||||
+ class CustomBytes(bytes): pass
|
||||
+ class CustomByteArray(bytearray): pass
|
||||
|
||||
|
||||
factories = [
|
||||
bytes,
|
||||
@@ -184,30 +312,31 @@ class GeneralFloatCases(unittest.TestCase):
|
||||
|
||||
|
||||
def test_floatconversion(self):
|
||||
# Make sure that calls to __float__() work properly
|
||||
- class Foo1(object):
|
||||
- def __float__(self):
|
||||
- return 42.
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Foo1(object):
|
||||
+ def __float__(self):
|
||||
+ return 42.
|
||||
|
||||
|
||||
- class Foo2(float):
|
||||
- def __float__(self):
|
||||
- return 42.
|
||||
+ class Foo2(float):
|
||||
+ def __float__(self):
|
||||
+ return 42.
|
||||
|
||||
|
||||
- class Foo3(float):
|
||||
- def __new__(cls, value=0.):
|
||||
- return float.__new__(cls, 2*value)
|
||||
+ class Foo3(float):
|
||||
+ def __new__(cls, value=0.):
|
||||
+ return float.__new__(cls, 2*value)
|
||||
|
||||
|
||||
- def __float__(self):
|
||||
- return self
|
||||
+ def __float__(self):
|
||||
+ return self
|
||||
|
||||
|
||||
- class Foo4(float):
|
||||
- def __float__(self):
|
||||
- return 42
|
||||
+ class Foo4(float):
|
||||
+ def __float__(self):
|
||||
+ return 42
|
||||
|
||||
|
||||
- # Issue 5759: __float__ not called on str subclasses (though it is on
|
||||
- # unicode subclasses).
|
||||
- class FooStr(str):
|
||||
|
|
@ -221,27 +221,27 @@ index 97f951f1299..da82bd190c3 100644
|
|||
+ class FooStr(str):
|
||||
+ def __float__(self):
|
||||
+ return float(str(self)) + 1
|
||||
|
||||
|
||||
self.assertEqual(float(Foo1()), 42.)
|
||||
self.assertEqual(float(Foo2()), 42.)
|
||||
@@ -216,15 +345,17 @@ class GeneralFloatCases(unittest.TestCase):
|
||||
self.assertRaises(TypeError, float, Foo4(42))
|
||||
self.assertEqual(float(FooStr('8')), 9.)
|
||||
|
||||
|
||||
- class Foo5:
|
||||
- def __float__(self):
|
||||
- return ""
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Foo5:
|
||||
+ def __float__(self):
|
||||
+ return ""
|
||||
self.assertRaises(TypeError, time.sleep, Foo5())
|
||||
|
||||
|
||||
- # Issue #24731
|
||||
- class F:
|
||||
- def __float__(self):
|
||||
- return OtherFloatSubclass(42.)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Issue #24731
|
||||
+ class F:
|
||||
+ def __float__(self):
|
||||
|
|
@ -252,39 +252,39 @@ index 97f951f1299..da82bd190c3 100644
|
|||
@@ -234,18 +365,20 @@ class GeneralFloatCases(unittest.TestCase):
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(type(FloatSubclass(F())), FloatSubclass)
|
||||
|
||||
|
||||
- class MyIndex:
|
||||
- def __init__(self, value):
|
||||
- self.value = value
|
||||
- def __index__(self):
|
||||
- return self.value
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyIndex:
|
||||
+ def __init__(self, value):
|
||||
+ self.value = value
|
||||
+ def __index__(self):
|
||||
+ return self.value
|
||||
|
||||
|
||||
self.assertEqual(float(MyIndex(42)), 42.0)
|
||||
self.assertRaises(OverflowError, float, MyIndex(2**2000))
|
||||
|
||||
|
||||
- class MyInt:
|
||||
- def __int__(self):
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyInt:
|
||||
+ def __int__(self):
|
||||
+ return 42
|
||||
|
||||
|
||||
self.assertRaises(TypeError, float, MyInt())
|
||||
|
||||
|
||||
@@ -254,27 +387,30 @@ class GeneralFloatCases(unittest.TestCase):
|
||||
float(x='3.14')
|
||||
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
- class subclass(float):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass(float):
|
||||
+ pass
|
||||
u = subclass(2.5)
|
||||
|
|
@ -292,11 +292,11 @@ index 97f951f1299..da82bd190c3 100644
|
|||
self.assertEqual(float(u), 2.5)
|
||||
with self.assertRaises(TypeError):
|
||||
subclass(x=0)
|
||||
|
||||
|
||||
- class subclass_with_init(float):
|
||||
- def __init__(self, arg, newarg=None):
|
||||
- self.newarg = newarg
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_init(float):
|
||||
+ def __init__(self, arg, newarg=None):
|
||||
+ self.newarg = newarg
|
||||
|
|
@ -304,13 +304,13 @@ index 97f951f1299..da82bd190c3 100644
|
|||
self.assertIs(type(u), subclass_with_init)
|
||||
self.assertEqual(float(u), 2.5)
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
|
||||
- class subclass_with_new(float):
|
||||
- def __new__(cls, arg, newarg=None):
|
||||
- self = super().__new__(cls, arg)
|
||||
- self.newarg = newarg
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_new(float):
|
||||
+ def __new__(cls, arg, newarg=None):
|
||||
+ self = super().__new__(cls, arg)
|
||||
|
|
@ -328,7 +328,7 @@ index 97f951f1299..da82bd190c3 100644
|
|||
- return 42
|
||||
- class F(float, H):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class H:
|
||||
+ def __hash__(self):
|
||||
+ return 42
|
||||
|
|
@ -336,8 +336,8 @@ index 97f951f1299..da82bd190c3 100644
|
|||
+ pass
|
||||
value = F('nan')
|
||||
self.assertEqual(hash(value), object.__hash__(value))
|
||||
|
||||
|
||||
|
||||
|
||||
@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
|
||||
-class FormatFunctionsTestCase(unittest.TestCase):
|
||||
+class FormatFunctionsTestCase(__TestCase):
|
||||
|
|
@ -347,25 +347,25 @@ index 97f951f1299..da82bd190c3 100644
|
|||
@@ -645,7 +782,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN))
|
||||
# is accident (today).
|
||||
# let's also try to guarantee that -0.0 and 0.0 don't get confused.
|
||||
|
||||
|
||||
-class IEEEFormatTestCase(unittest.TestCase):
|
||||
+class IEEEFormatTestCase(__TestCase):
|
||||
|
||||
|
||||
@support.requires_IEEE_754
|
||||
def test_double_specials_do_unpack(self):
|
||||
@@ -670,7 +807,7 @@ class IEEEFormatTestCase(unittest.TestCase):
|
||||
self.assertEqual(struct.pack("<f", 3.40282356e38), struct.pack("<f", FLT_MAX))
|
||||
self.assertEqual(struct.pack("<f", -3.40282356e38), struct.pack("<f", -FLT_MAX))
|
||||
|
||||
|
||||
-class FormatTestCase(unittest.TestCase):
|
||||
+class FormatTestCase(__TestCase):
|
||||
|
||||
|
||||
def test_format(self):
|
||||
# these should be rewritten to use both format(x, spec) and
|
||||
@@ -767,7 +904,7 @@ class FormatTestCase(unittest.TestCase):
|
||||
self.assertEqual(format(-123.34, '00.10e'), '-1.2334000000e+02')
|
||||
self.assertEqual(format(-123.34, '00.10g'), '-123.34')
|
||||
|
||||
|
||||
-class ReprTestCase(unittest.TestCase):
|
||||
+class ReprTestCase(__TestCase):
|
||||
def test_repr(self):
|
||||
|
|
@ -373,7 +373,7 @@ index 97f951f1299..da82bd190c3 100644
|
|||
'mathdata',
|
||||
@@ -832,7 +969,29 @@ class ReprTestCase(unittest.TestCase):
|
||||
self.assertEqual(repr(float(negs)), str(float(negs)))
|
||||
|
||||
|
||||
@support.requires_IEEE_754
|
||||
-class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
|
||||
+class RoundTestCase(__TestCase):
|
||||
|
|
@ -399,11 +399,11 @@ index 97f951f1299..da82bd190c3 100644
|
|||
+ else:
|
||||
+ msg += ': zeros have different signs'
|
||||
+ self.fail(msg.format(x, y))
|
||||
|
||||
|
||||
def test_inf_nan(self):
|
||||
self.assertRaises(OverflowError, round, INF)
|
||||
@@ -955,7 +1114,7 @@ class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
|
||||
|
||||
|
||||
# Beginning with Python 2.6 float has cross platform compatible
|
||||
# ways to create and represent inf and nan
|
||||
-class InfNanTest(unittest.TestCase):
|
||||
|
|
@ -412,7 +412,7 @@ index 97f951f1299..da82bd190c3 100644
|
|||
self.assertTrue(isinf(float("inf")))
|
||||
self.assertTrue(isinf(float("+inf")))
|
||||
@@ -1056,12 +1215,35 @@ class InfNanTest(unittest.TestCase):
|
||||
|
||||
|
||||
fromHex = float.fromhex
|
||||
toHex = float.hex
|
||||
-class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
|
||||
|
|
@ -421,7 +421,7 @@ index 97f951f1299..da82bd190c3 100644
|
|||
MIN = fromHex('0x1p-1022') # min normal
|
||||
TINY = fromHex('0x0.0000000000001p-1022') # min subnormal
|
||||
EPS = fromHex('0x0.0000000000001p0') # diff between 1.0 and next float up
|
||||
|
||||
|
||||
+ def assertFloatsAreIdentical(self, x, y):
|
||||
+ """assert that floats x and y are identical, in the sense that:
|
||||
+ (1) both x and y are nans, or
|
||||
|
|
@ -447,37 +447,37 @@ index 97f951f1299..da82bd190c3 100644
|
|||
+
|
||||
def identical(self, x, y):
|
||||
self.assertFloatsAreIdentical(x, y)
|
||||
|
||||
|
||||
@@ -1482,17 +1664,19 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
|
||||
self.identical(x, fromHex(toHex(x)))
|
||||
|
||||
|
||||
def test_subclass(self):
|
||||
- class F(float):
|
||||
- def __new__(cls, value):
|
||||
- return float.__new__(cls, value + 1)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class F(float):
|
||||
+ def __new__(cls, value):
|
||||
+ return float.__new__(cls, value + 1)
|
||||
|
||||
|
||||
f = F.fromhex((1.5).hex())
|
||||
self.assertIs(type(f), F)
|
||||
self.assertEqual(f, 2.5)
|
||||
|
||||
|
||||
- class F2(float):
|
||||
- def __init__(self, value):
|
||||
- self.foo = 'bar'
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class F2(float):
|
||||
+ def __init__(self, value):
|
||||
+ self.foo = 'bar'
|
||||
|
||||
|
||||
f = F2.fromhex((1.5).hex())
|
||||
self.assertIs(type(f), F2)
|
||||
@@ -1500,5 +1684,5 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
|
||||
self.assertEqual(getattr(f, 'foo', 'none'), 'bar')
|
||||
|
||||
|
||||
|
||||
|
||||
-if __name__ == '__main__':
|
||||
- unittest.main()
|
||||
+if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ class GeneralFloatCases(__TestCase):
|
|||
def test_non_numeric_input_types(self):
|
||||
# Test possible non-numeric types for the argument x, including
|
||||
# subclasses of the explicitly documented accepted types.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomStr(str): pass
|
||||
class CustomBytes(bytes): pass
|
||||
class CustomByteArray(bytearray): pass
|
||||
|
|
@ -312,7 +312,7 @@ class GeneralFloatCases(__TestCase):
|
|||
|
||||
def test_floatconversion(self):
|
||||
# Make sure that calls to __float__() work properly
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Foo1(object):
|
||||
def __float__(self):
|
||||
return 42.
|
||||
|
|
@ -345,13 +345,13 @@ class GeneralFloatCases(__TestCase):
|
|||
self.assertRaises(TypeError, float, Foo4(42))
|
||||
self.assertEqual(float(FooStr('8')), 9.)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Foo5:
|
||||
def __float__(self):
|
||||
return ""
|
||||
self.assertRaises(TypeError, time.sleep, Foo5())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Issue #24731
|
||||
class F:
|
||||
def __float__(self):
|
||||
|
|
@ -365,7 +365,7 @@ class GeneralFloatCases(__TestCase):
|
|||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(type(FloatSubclass(F())), FloatSubclass)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyIndex:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
|
@ -375,7 +375,7 @@ class GeneralFloatCases(__TestCase):
|
|||
self.assertEqual(float(MyIndex(42)), 42.0)
|
||||
self.assertRaises(OverflowError, float, MyIndex(2**2000))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyInt:
|
||||
def __int__(self):
|
||||
return 42
|
||||
|
|
@ -387,7 +387,7 @@ class GeneralFloatCases(__TestCase):
|
|||
float(x='3.14')
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass(float):
|
||||
pass
|
||||
u = subclass(2.5)
|
||||
|
|
@ -396,7 +396,7 @@ class GeneralFloatCases(__TestCase):
|
|||
with self.assertRaises(TypeError):
|
||||
subclass(x=0)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_init(float):
|
||||
def __init__(self, arg, newarg=None):
|
||||
self.newarg = newarg
|
||||
|
|
@ -405,7 +405,7 @@ class GeneralFloatCases(__TestCase):
|
|||
self.assertEqual(float(u), 2.5)
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_new(float):
|
||||
def __new__(cls, arg, newarg=None):
|
||||
self = super().__new__(cls, arg)
|
||||
|
|
@ -746,7 +746,7 @@ class GeneralFloatCases(__TestCase):
|
|||
def test_hash_nan(self):
|
||||
value = float('nan')
|
||||
self.assertEqual(hash(value), object.__hash__(value))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class H:
|
||||
def __hash__(self):
|
||||
return 42
|
||||
|
|
@ -1664,7 +1664,7 @@ class HexFloatTestCase(__TestCase):
|
|||
self.identical(x, fromHex(toHex(x)))
|
||||
|
||||
def test_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class F(float):
|
||||
def __new__(cls, value):
|
||||
return float.__new__(cls, value + 1)
|
||||
|
|
@ -1673,7 +1673,7 @@ class HexFloatTestCase(__TestCase):
|
|||
self.assertIs(type(f), F)
|
||||
self.assertEqual(f, 2.5)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class F2(float):
|
||||
def __init__(self, value):
|
||||
self.foo = 'bar'
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ index 48825f46911..731680d82a0 100644
|
|||
+
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from test import support
|
||||
|
|
@ -144,35 +144,35 @@ index 48825f46911..731680d82a0 100644
|
|||
+ '(1+1.5_j_)',
|
||||
+ '(1+1.5_j)',
|
||||
+]
|
||||
|
||||
|
||||
try:
|
||||
import _pylong
|
||||
@@ -38,7 +165,7 @@ L = [
|
||||
class IntSubclass(int):
|
||||
pass
|
||||
|
||||
|
||||
-class IntTestCases(unittest.TestCase):
|
||||
+class IntTestCases(__TestCase):
|
||||
|
||||
|
||||
def test_basic(self):
|
||||
self.assertEqual(int(314), 314)
|
||||
@@ -309,11 +436,13 @@ class IntTestCases(unittest.TestCase):
|
||||
int('0', 5.0)
|
||||
|
||||
|
||||
def test_int_base_indexable(self):
|
||||
- class MyIndexable(object):
|
||||
- def __init__(self, value):
|
||||
- self.value = value
|
||||
- def __index__(self):
|
||||
- return self.value
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyIndexable(object):
|
||||
+ def __init__(self, value):
|
||||
+ self.value = value
|
||||
+ def __index__(self):
|
||||
+ return self.value
|
||||
|
||||
|
||||
# Check out of range bases.
|
||||
for base in 2**100, -2**100, 1, 37:
|
||||
@@ -328,9 +457,11 @@ class IntTestCases(unittest.TestCase):
|
||||
|
|
@ -183,44 +183,44 @@ index 48825f46911..731680d82a0 100644
|
|||
- class CustomBytes(bytes): pass
|
||||
- class CustomByteArray(bytearray): pass
|
||||
+
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class CustomStr(str): pass
|
||||
+ class CustomBytes(bytes): pass
|
||||
+ class CustomByteArray(bytearray): pass
|
||||
|
||||
|
||||
factories = [
|
||||
bytes,
|
||||
@@ -372,72 +503,82 @@ class IntTestCases(unittest.TestCase):
|
||||
|
||||
|
||||
def test_intconversion(self):
|
||||
# Test __int__()
|
||||
- class ClassicMissingMethods:
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ClassicMissingMethods:
|
||||
+ pass
|
||||
self.assertRaises(TypeError, int, ClassicMissingMethods())
|
||||
|
||||
|
||||
- class MissingMethods(object):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MissingMethods(object):
|
||||
+ pass
|
||||
self.assertRaises(TypeError, int, MissingMethods())
|
||||
|
||||
|
||||
- class Foo0:
|
||||
- def __int__(self):
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Foo0:
|
||||
+ def __int__(self):
|
||||
+ return 42
|
||||
|
||||
|
||||
self.assertEqual(int(Foo0()), 42)
|
||||
|
||||
|
||||
- class Classic:
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Classic:
|
||||
+ pass
|
||||
for base in (object, Classic):
|
||||
|
|
@ -229,35 +229,35 @@ index 48825f46911..731680d82a0 100644
|
|||
- return 42
|
||||
- def __trunc__(self):
|
||||
- return -12
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class IntOverridesTrunc(base):
|
||||
+ def __int__(self):
|
||||
+ return 42
|
||||
+ def __trunc__(self):
|
||||
+ return -12
|
||||
self.assertEqual(int(IntOverridesTrunc()), 42)
|
||||
|
||||
|
||||
- class JustTrunc(base):
|
||||
- def __trunc__(self):
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class JustTrunc(base):
|
||||
+ def __trunc__(self):
|
||||
+ return 42
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertEqual(int(JustTrunc()), 42)
|
||||
|
||||
|
||||
- class ExceptionalTrunc(base):
|
||||
- def __trunc__(self):
|
||||
- 1 / 0
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ExceptionalTrunc(base):
|
||||
+ def __trunc__(self):
|
||||
+ 1 / 0
|
||||
with self.assertRaises(ZeroDivisionError), \
|
||||
self.assertWarns(DeprecationWarning):
|
||||
int(ExceptionalTrunc())
|
||||
|
||||
|
||||
for trunc_result_base in (object, Classic):
|
||||
- class Index(trunc_result_base):
|
||||
- def __index__(self):
|
||||
|
|
@ -266,7 +266,7 @@ index 48825f46911..731680d82a0 100644
|
|||
- class TruncReturnsNonInt(base):
|
||||
- def __trunc__(self):
|
||||
- return Index()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Index(trunc_result_base):
|
||||
+ def __index__(self):
|
||||
+ return 42
|
||||
|
|
@ -276,15 +276,15 @@ index 48825f46911..731680d82a0 100644
|
|||
+ return Index()
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertEqual(int(TruncReturnsNonInt()), 42)
|
||||
|
||||
|
||||
- class Intable(trunc_result_base):
|
||||
- def __int__(self):
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Intable(trunc_result_base):
|
||||
+ def __int__(self):
|
||||
+ return 42
|
||||
|
||||
|
||||
- class TruncReturnsNonIndex(base):
|
||||
- def __trunc__(self):
|
||||
- return Intable()
|
||||
|
|
@ -293,17 +293,17 @@ index 48825f46911..731680d82a0 100644
|
|||
+ return Intable()
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertEqual(int(TruncReturnsNonInt()), 42)
|
||||
|
||||
|
||||
- class NonIntegral(trunc_result_base):
|
||||
- def __trunc__(self):
|
||||
- # Check that we avoid infinite recursion.
|
||||
- return NonIntegral()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class NonIntegral(trunc_result_base):
|
||||
+ def __trunc__(self):
|
||||
+ # Check that we avoid infinite recursion.
|
||||
+ return NonIntegral()
|
||||
|
||||
|
||||
- class TruncReturnsNonIntegral(base):
|
||||
- def __trunc__(self):
|
||||
- return NonIntegral()
|
||||
|
|
@ -316,152 +316,152 @@ index 48825f46911..731680d82a0 100644
|
|||
@@ -449,27 +590,29 @@ class IntTestCases(unittest.TestCase):
|
||||
self.fail("Failed to raise TypeError with %s" %
|
||||
((base, trunc_result_base),))
|
||||
|
||||
|
||||
- # Regression test for bugs.python.org/issue16060.
|
||||
- class BadInt(trunc_result_base):
|
||||
- def __int__(self):
|
||||
- return 42.0
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # Regression test for bugs.python.org/issue16060.
|
||||
+ class BadInt(trunc_result_base):
|
||||
+ def __int__(self):
|
||||
+ return 42.0
|
||||
|
||||
|
||||
- class TruncReturnsBadInt(base):
|
||||
- def __trunc__(self):
|
||||
- return BadInt()
|
||||
+ class TruncReturnsBadInt(base):
|
||||
+ def __trunc__(self):
|
||||
+ return BadInt()
|
||||
|
||||
|
||||
with self.assertRaises(TypeError), \
|
||||
self.assertWarns(DeprecationWarning):
|
||||
int(TruncReturnsBadInt())
|
||||
|
||||
|
||||
def test_int_subclass_with_index(self):
|
||||
- class MyIndex(int):
|
||||
- def __index__(self):
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyIndex(int):
|
||||
+ def __index__(self):
|
||||
+ return 42
|
||||
|
||||
|
||||
- class BadIndex(int):
|
||||
- def __index__(self):
|
||||
- return 42.0
|
||||
+ class BadIndex(int):
|
||||
+ def __index__(self):
|
||||
+ return 42.0
|
||||
|
||||
|
||||
my_int = MyIndex(7)
|
||||
self.assertEqual(my_int, 7)
|
||||
@@ -478,13 +621,14 @@ class IntTestCases(unittest.TestCase):
|
||||
self.assertEqual(int(BadIndex()), 0)
|
||||
|
||||
|
||||
def test_int_subclass_with_int(self):
|
||||
- class MyInt(int):
|
||||
- def __int__(self):
|
||||
- return 42
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyInt(int):
|
||||
+ def __int__(self):
|
||||
+ return 42
|
||||
|
||||
|
||||
- class BadInt(int):
|
||||
- def __int__(self):
|
||||
- return 42.0
|
||||
+ class BadInt(int):
|
||||
+ def __int__(self):
|
||||
+ return 42.0
|
||||
|
||||
|
||||
my_int = MyInt(7)
|
||||
self.assertEqual(my_int, 7)
|
||||
@@ -495,33 +639,34 @@ class IntTestCases(unittest.TestCase):
|
||||
self.assertRaises(TypeError, int, my_int)
|
||||
|
||||
|
||||
def test_int_returns_int_subclass(self):
|
||||
- class BadIndex:
|
||||
- def __index__(self):
|
||||
- return True
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadIndex:
|
||||
+ def __index__(self):
|
||||
+ return True
|
||||
|
||||
|
||||
- class BadIndex2(int):
|
||||
- def __index__(self):
|
||||
- return True
|
||||
+ class BadIndex2(int):
|
||||
+ def __index__(self):
|
||||
+ return True
|
||||
|
||||
|
||||
- class BadInt:
|
||||
- def __int__(self):
|
||||
- return True
|
||||
+ class BadInt:
|
||||
+ def __int__(self):
|
||||
+ return True
|
||||
|
||||
|
||||
- class BadInt2(int):
|
||||
- def __int__(self):
|
||||
- return True
|
||||
+ class BadInt2(int):
|
||||
+ def __int__(self):
|
||||
+ return True
|
||||
|
||||
|
||||
- class TruncReturnsBadIndex:
|
||||
- def __trunc__(self):
|
||||
- return BadIndex()
|
||||
+ class TruncReturnsBadIndex:
|
||||
+ def __trunc__(self):
|
||||
+ return BadIndex()
|
||||
|
||||
|
||||
- class TruncReturnsBadInt:
|
||||
- def __trunc__(self):
|
||||
- return BadInt()
|
||||
+ class TruncReturnsBadInt:
|
||||
+ def __trunc__(self):
|
||||
+ return BadInt()
|
||||
|
||||
|
||||
- class TruncReturnsIntSubclass:
|
||||
- def __trunc__(self):
|
||||
- return True
|
||||
+ class TruncReturnsIntSubclass:
|
||||
+ def __trunc__(self):
|
||||
+ return True
|
||||
|
||||
|
||||
bad_int = BadIndex()
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
@@ -566,6 +711,7 @@ class IntTestCases(unittest.TestCase):
|
||||
self.assertEqual(n, 1)
|
||||
self.assertIs(type(n), IntSubclass)
|
||||
|
||||
|
||||
+ @skipIfTorchDynamo("flaky under dynamo")
|
||||
def test_error_message(self):
|
||||
def check(s, base=None):
|
||||
with self.assertRaises(ValueError,
|
||||
@@ -607,7 +753,7 @@ class IntTestCases(unittest.TestCase):
|
||||
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
|
||||
|
||||
|
||||
|
||||
|
||||
-class IntStrDigitLimitsTests(unittest.TestCase):
|
||||
+class IntStrDigitLimitsTests(__TestCase):
|
||||
|
||||
|
||||
int_class = int # Override this in subclasses to reuse the suite.
|
||||
|
||||
|
||||
@@ -818,7 +964,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
|
||||
int_class = IntSubclass
|
||||
|
||||
|
||||
|
||||
|
||||
-class PyLongModuleTests(unittest.TestCase):
|
||||
+class PyLongModuleTests(__TestCase):
|
||||
# Tests of the functions in _pylong.py. Those get used when the
|
||||
# number of digits in the input values are large enough.
|
||||
|
||||
|
||||
@@ -922,4 +1068,4 @@ class PyLongModuleTests(unittest.TestCase):
|
||||
bits <<= 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -436,8 +436,8 @@ class IntTestCases(__TestCase):
|
|||
int('0', 5.0)
|
||||
|
||||
def test_int_base_indexable(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyIndexable(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
|
@ -458,7 +458,7 @@ class IntTestCases(__TestCase):
|
|||
# Test possible non-numeric types for the argument x, including
|
||||
# subclasses of the explicitly documented accepted types.
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomStr(str): pass
|
||||
class CustomBytes(bytes): pass
|
||||
class CustomByteArray(bytearray): pass
|
||||
|
|
@ -503,28 +503,28 @@ class IntTestCases(__TestCase):
|
|||
|
||||
def test_intconversion(self):
|
||||
# Test __int__()
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ClassicMissingMethods:
|
||||
pass
|
||||
self.assertRaises(TypeError, int, ClassicMissingMethods())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MissingMethods(object):
|
||||
pass
|
||||
self.assertRaises(TypeError, int, MissingMethods())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Foo0:
|
||||
def __int__(self):
|
||||
return 42
|
||||
|
||||
self.assertEqual(int(Foo0()), 42)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Classic:
|
||||
pass
|
||||
for base in (object, Classic):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class IntOverridesTrunc(base):
|
||||
def __int__(self):
|
||||
return 42
|
||||
|
|
@ -532,14 +532,14 @@ class IntTestCases(__TestCase):
|
|||
return -12
|
||||
self.assertEqual(int(IntOverridesTrunc()), 42)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class JustTrunc(base):
|
||||
def __trunc__(self):
|
||||
return 42
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertEqual(int(JustTrunc()), 42)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ExceptionalTrunc(base):
|
||||
def __trunc__(self):
|
||||
1 / 0
|
||||
|
|
@ -548,7 +548,7 @@ class IntTestCases(__TestCase):
|
|||
int(ExceptionalTrunc())
|
||||
|
||||
for trunc_result_base in (object, Classic):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Index(trunc_result_base):
|
||||
def __index__(self):
|
||||
return 42
|
||||
|
|
@ -559,7 +559,7 @@ class IntTestCases(__TestCase):
|
|||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertEqual(int(TruncReturnsNonInt()), 42)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Intable(trunc_result_base):
|
||||
def __int__(self):
|
||||
return 42
|
||||
|
|
@ -570,7 +570,7 @@ class IntTestCases(__TestCase):
|
|||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertEqual(int(TruncReturnsNonInt()), 42)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NonIntegral(trunc_result_base):
|
||||
def __trunc__(self):
|
||||
# Check that we avoid infinite recursion.
|
||||
|
|
@ -590,7 +590,7 @@ class IntTestCases(__TestCase):
|
|||
self.fail("Failed to raise TypeError with %s" %
|
||||
((base, trunc_result_base),))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# Regression test for bugs.python.org/issue16060.
|
||||
class BadInt(trunc_result_base):
|
||||
def __int__(self):
|
||||
|
|
@ -605,7 +605,7 @@ class IntTestCases(__TestCase):
|
|||
int(TruncReturnsBadInt())
|
||||
|
||||
def test_int_subclass_with_index(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyIndex(int):
|
||||
def __index__(self):
|
||||
return 42
|
||||
|
|
@ -621,7 +621,7 @@ class IntTestCases(__TestCase):
|
|||
self.assertEqual(int(BadIndex()), 0)
|
||||
|
||||
def test_int_subclass_with_int(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyInt(int):
|
||||
def __int__(self):
|
||||
return 42
|
||||
|
|
@ -639,7 +639,7 @@ class IntTestCases(__TestCase):
|
|||
self.assertRaises(TypeError, int, my_int)
|
||||
|
||||
def test_int_returns_int_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadIndex:
|
||||
def __index__(self):
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -61,15 +61,15 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+# ======= END DYNAMO PATCH =======
|
||||
+
|
||||
# Test iterators.
|
||||
|
||||
|
||||
import sys
|
||||
@@ -104,12 +161,10 @@ class EmptyIterClass:
|
||||
|
||||
|
||||
# Main test suite
|
||||
|
||||
|
||||
-class TestCase(unittest.TestCase):
|
||||
+class TestCase(__TestCase):
|
||||
|
||||
|
||||
# Helper to check that an iterator returns a given sequence
|
||||
def check_iterator(self, it, seq, pickle=True):
|
||||
- if pickle:
|
||||
|
|
@ -78,7 +78,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
while 1:
|
||||
try:
|
||||
@@ -121,8 +176,6 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
# Helper to check that a for loop generates a given sequence
|
||||
def check_for_loop(self, expr, seq, pickle=True):
|
||||
- if pickle:
|
||||
|
|
@ -89,7 +89,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
@@ -261,19 +314,20 @@ class TestCase(unittest.TestCase):
|
||||
def run(builtin_name, item, sentinel=None):
|
||||
it = iter(item) if sentinel is None else iter(item, sentinel)
|
||||
|
||||
|
||||
- class CustomStr:
|
||||
- def __init__(self, name, iterator):
|
||||
- self.name = name
|
||||
|
|
@ -103,7 +103,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
- # the pointers after this call
|
||||
- list(self.iterator)
|
||||
- return other == self.name
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class CustomStr:
|
||||
+ def __init__(self, name, iterator):
|
||||
+ self.name = name
|
||||
|
|
@ -117,25 +117,25 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+ # the pointers after this call
|
||||
+ list(self.iterator)
|
||||
+ return other == self.name
|
||||
|
||||
|
||||
# del is required here
|
||||
# to not prematurely call __eq__ from
|
||||
@@ -323,9 +377,10 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
# Test a new_style class with __iter__ but no next() method
|
||||
def test_new_style_iter_class(self):
|
||||
- class IterClass(object):
|
||||
- def __iter__(self):
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class IterClass(object):
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
self.assertRaises(TypeError, iter, IterClass())
|
||||
|
||||
|
||||
# Test two-argument iter() with callable instance
|
||||
@@ -394,11 +449,12 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
# Test exception propagation through sequence iterator
|
||||
def test_exception_sequence(self):
|
||||
- class MySequenceClass(SequenceClass):
|
||||
|
|
@ -143,7 +143,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
- if i == 10:
|
||||
- raise RuntimeError
|
||||
- return SequenceClass.__getitem__(self, i)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySequenceClass(SequenceClass):
|
||||
+ def __getitem__(self, i):
|
||||
+ if i == 10:
|
||||
|
|
@ -153,7 +153,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
try:
|
||||
for x in MySequenceClass(20):
|
||||
@@ -410,11 +466,12 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
# Test for StopIteration from __getitem__
|
||||
def test_stop_sequence(self):
|
||||
- class MySequenceClass(SequenceClass):
|
||||
|
|
@ -161,25 +161,25 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
- if i == 10:
|
||||
- raise StopIteration
|
||||
- return SequenceClass.__getitem__(self, i)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MySequenceClass(SequenceClass):
|
||||
+ def __getitem__(self, i):
|
||||
+ if i == 10:
|
||||
+ raise StopIteration
|
||||
+ return SequenceClass.__getitem__(self, i)
|
||||
self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
|
||||
|
||||
|
||||
# Test a big range
|
||||
@@ -541,32 +598,34 @@ class TestCase(unittest.TestCase):
|
||||
self.assertRaises(TypeError, filter, None, list)
|
||||
self.assertRaises(TypeError, filter, None, 42)
|
||||
|
||||
|
||||
- class Boolean:
|
||||
- def __init__(self, truth):
|
||||
- self.truth = truth
|
||||
- def __bool__(self):
|
||||
- return self.truth
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Boolean:
|
||||
+ def __init__(self, truth):
|
||||
+ self.truth = truth
|
||||
|
|
@ -187,7 +187,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+ return self.truth
|
||||
bTrue = Boolean(True)
|
||||
bFalse = Boolean(False)
|
||||
|
||||
|
||||
- class Seq:
|
||||
- def __init__(self, *args):
|
||||
- self.vals = args
|
||||
|
|
@ -206,7 +206,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
- else:
|
||||
- raise StopIteration
|
||||
- return SeqIter(self.vals)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Seq:
|
||||
+ def __init__(self, *args):
|
||||
+ self.vals = args
|
||||
|
|
@ -225,12 +225,12 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+ else:
|
||||
+ raise StopIteration
|
||||
+ return SeqIter(self.vals)
|
||||
|
||||
|
||||
seq = Seq(*([bTrue, bFalse] * 25))
|
||||
self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
|
||||
@@ -635,6 +694,7 @@ class TestCase(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
# Test zip()'s use of iterators.
|
||||
+ @skipIfTorchDynamo("infinite loop")
|
||||
def test_builtin_zip(self):
|
||||
|
|
@ -238,21 +238,21 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
self.assertEqual(list(zip(*[])), [])
|
||||
@@ -653,17 +713,18 @@ class TestCase(unittest.TestCase):
|
||||
self.assertEqual(list(d.items()), list(zip(d, d.values())))
|
||||
|
||||
|
||||
# Generate all ints starting at constructor arg.
|
||||
- class IntsFrom:
|
||||
- def __init__(self, start):
|
||||
- self.i = start
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class IntsFrom:
|
||||
+ def __init__(self, start):
|
||||
+ self.i = start
|
||||
|
||||
|
||||
- def __iter__(self):
|
||||
- return self
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
|
||||
|
||||
- def __next__(self):
|
||||
- i = self.i
|
||||
- self.i = i+1
|
||||
|
|
@ -261,60 +261,60 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+ i = self.i
|
||||
+ self.i = i+1
|
||||
+ return i
|
||||
|
||||
|
||||
f = open(TESTFN, "w", encoding="utf-8")
|
||||
try:
|
||||
@@ -686,19 +747,20 @@ class TestCase(unittest.TestCase):
|
||||
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
|
||||
|
||||
|
||||
# Classes that lie about their lengths.
|
||||
- class NoGuessLen5:
|
||||
- def __getitem__(self, i):
|
||||
- if i >= 5:
|
||||
- raise IndexError
|
||||
- return i
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class NoGuessLen5:
|
||||
+ def __getitem__(self, i):
|
||||
+ if i >= 5:
|
||||
+ raise IndexError
|
||||
+ return i
|
||||
|
||||
|
||||
- class Guess3Len5(NoGuessLen5):
|
||||
- def __len__(self):
|
||||
- return 3
|
||||
+ class Guess3Len5(NoGuessLen5):
|
||||
+ def __len__(self):
|
||||
+ return 3
|
||||
|
||||
|
||||
- class Guess30Len5(NoGuessLen5):
|
||||
- def __len__(self):
|
||||
- return 30
|
||||
+ class Guess30Len5(NoGuessLen5):
|
||||
+ def __len__(self):
|
||||
+ return 30
|
||||
|
||||
|
||||
def lzip(*args):
|
||||
return list(zip(*args))
|
||||
@@ -718,20 +780,21 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
# This class inserts a Unicode object into its argument's natural
|
||||
# iteration, in the 3rd position.
|
||||
- class OhPhooey:
|
||||
- def __init__(self, seq):
|
||||
- self.it = iter(seq)
|
||||
- self.i = 0
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class OhPhooey:
|
||||
+ def __init__(self, seq):
|
||||
+ self.it = iter(seq)
|
||||
+ self.i = 0
|
||||
|
||||
|
||||
- def __iter__(self):
|
||||
- return self
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
|
||||
|
||||
- def __next__(self):
|
||||
- i = self.i
|
||||
- self.i = i+1
|
||||
|
|
@ -327,25 +327,25 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+ if i == 2:
|
||||
+ return "fooled you!"
|
||||
+ return next(self.it)
|
||||
|
||||
|
||||
f = open(TESTFN, "w", encoding="utf-8")
|
||||
try:
|
||||
@@ -895,29 +958,30 @@ class TestCase(unittest.TestCase):
|
||||
f.writelines({})
|
||||
|
||||
|
||||
# Try a big chunk too.
|
||||
- class Iterator:
|
||||
- def __init__(self, start, finish):
|
||||
- self.start = start
|
||||
- self.finish = finish
|
||||
- self.i = self.start
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Iterator:
|
||||
+ def __init__(self, start, finish):
|
||||
+ self.start = start
|
||||
+ self.finish = finish
|
||||
+ self.i = self.start
|
||||
|
||||
|
||||
- def __next__(self):
|
||||
- if self.i >= self.finish:
|
||||
- raise StopIteration
|
||||
|
|
@ -358,12 +358,12 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+ result = str(self.i) + '\n'
|
||||
+ self.i += 1
|
||||
+ return result
|
||||
|
||||
|
||||
- def __iter__(self):
|
||||
- return self
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
|
||||
|
||||
- class Whatever:
|
||||
- def __init__(self, start, finish):
|
||||
- self.start = start
|
||||
|
|
@ -372,16 +372,16 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
+ def __init__(self, start, finish):
|
||||
+ self.start = start
|
||||
+ self.finish = finish
|
||||
|
||||
|
||||
- def __iter__(self):
|
||||
- return Iterator(self.start, self.finish)
|
||||
+ def __iter__(self):
|
||||
+ return Iterator(self.start, self.finish)
|
||||
|
||||
|
||||
f.writelines(Whatever(6, 6+2000))
|
||||
f.close()
|
||||
@@ -990,15 +1054,16 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
@cpython_only
|
||||
def test_ref_counting_behavior(self):
|
||||
- class C(object):
|
||||
|
|
@ -393,7 +393,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
- cls = self.__class__
|
||||
- assert cls.count > 0
|
||||
- cls.count -= 1
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ count = 0
|
||||
+ def __new__(cls):
|
||||
|
|
@ -407,7 +407,7 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
self.assertEqual(C.count, 1)
|
||||
del x
|
||||
@@ -1089,12 +1154,13 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
def test_3720(self):
|
||||
# Avoid a crash, when an iterator deletes its next() method.
|
||||
- class BadIterator(object):
|
||||
|
|
@ -416,19 +416,19 @@ index 1b9f3cf7624..6560c7423a6 100644
|
|||
- def __next__(self):
|
||||
- del BadIterator.__next__
|
||||
- return 1
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadIterator(object):
|
||||
+ def __iter__(self):
|
||||
+ return self
|
||||
+ def __next__(self):
|
||||
+ del BadIterator.__next__
|
||||
+ return 1
|
||||
|
||||
|
||||
try:
|
||||
for i in BadIterator() :
|
||||
@@ -1187,4 +1253,4 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -314,7 +314,7 @@ class TestCase(__TestCase):
|
|||
def run(builtin_name, item, sentinel=None):
|
||||
it = iter(item) if sentinel is None else iter(item, sentinel)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class CustomStr:
|
||||
def __init__(self, name, iterator):
|
||||
self.name = name
|
||||
|
|
@ -377,7 +377,7 @@ class TestCase(__TestCase):
|
|||
|
||||
# Test a new_style class with __iter__ but no next() method
|
||||
def test_new_style_iter_class(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class IterClass(object):
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
@ -449,7 +449,7 @@ class TestCase(__TestCase):
|
|||
|
||||
# Test exception propagation through sequence iterator
|
||||
def test_exception_sequence(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySequenceClass(SequenceClass):
|
||||
def __getitem__(self, i):
|
||||
if i == 10:
|
||||
|
|
@ -466,7 +466,7 @@ class TestCase(__TestCase):
|
|||
|
||||
# Test for StopIteration from __getitem__
|
||||
def test_stop_sequence(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MySequenceClass(SequenceClass):
|
||||
def __getitem__(self, i):
|
||||
if i == 10:
|
||||
|
|
@ -598,7 +598,7 @@ class TestCase(__TestCase):
|
|||
self.assertRaises(TypeError, filter, None, list)
|
||||
self.assertRaises(TypeError, filter, None, 42)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Boolean:
|
||||
def __init__(self, truth):
|
||||
self.truth = truth
|
||||
|
|
@ -607,7 +607,7 @@ class TestCase(__TestCase):
|
|||
bTrue = Boolean(True)
|
||||
bFalse = Boolean(False)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Seq:
|
||||
def __init__(self, *args):
|
||||
self.vals = args
|
||||
|
|
@ -713,7 +713,7 @@ class TestCase(__TestCase):
|
|||
self.assertEqual(list(d.items()), list(zip(d, d.values())))
|
||||
|
||||
# Generate all ints starting at constructor arg.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class IntsFrom:
|
||||
def __init__(self, start):
|
||||
self.i = start
|
||||
|
|
@ -747,7 +747,7 @@ class TestCase(__TestCase):
|
|||
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
|
||||
|
||||
# Classes that lie about their lengths.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class NoGuessLen5:
|
||||
def __getitem__(self, i):
|
||||
if i >= 5:
|
||||
|
|
@ -780,7 +780,7 @@ class TestCase(__TestCase):
|
|||
|
||||
# This class inserts a Unicode object into its argument's natural
|
||||
# iteration, in the 3rd position.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class OhPhooey:
|
||||
def __init__(self, seq):
|
||||
self.it = iter(seq)
|
||||
|
|
@ -958,7 +958,7 @@ class TestCase(__TestCase):
|
|||
f.writelines({})
|
||||
|
||||
# Try a big chunk too.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Iterator:
|
||||
def __init__(self, start, finish):
|
||||
self.start = start
|
||||
|
|
@ -1054,7 +1054,7 @@ class TestCase(__TestCase):
|
|||
|
||||
@cpython_only
|
||||
def test_ref_counting_behavior(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
count = 0
|
||||
def __new__(cls):
|
||||
|
|
@ -1154,7 +1154,7 @@ class TestCase(__TestCase):
|
|||
|
||||
def test_3720(self):
|
||||
# Avoid a crash, when an iterator deletes its next() method.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadIterator(object):
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
@@ -40,6 +62,14 @@ def pickle_deprecated(testfunc):
|
||||
maxsize = support.MAX_Py_ssize_t
|
||||
minsize = -maxsize-1
|
||||
|
||||
|
||||
+@torch._dynamo.disable
|
||||
+def choice(*args):
|
||||
+ return random.choice(*args)
|
||||
|
|
@ -42,33 +42,33 @@ index 7d5ba727389..ff514815da2 100644
|
|||
+
|
||||
def lzip(*args):
|
||||
return list(zip(*args))
|
||||
|
||||
|
||||
@@ -90,10 +120,10 @@ def fact(n):
|
||||
return prod(range(1, n+1))
|
||||
|
||||
|
||||
# root level methods for pickling ability
|
||||
-def testR(r):
|
||||
+def _testR(r):
|
||||
return r[0]
|
||||
|
||||
|
||||
-def testR2(r):
|
||||
+def _testR2(r):
|
||||
return r[2]
|
||||
|
||||
|
||||
def underten(x):
|
||||
@@ -102,7 +132,7 @@ def underten(x):
|
||||
picklecopiers = [lambda s, proto=proto: pickle.loads(pickle.dumps(s, proto))
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1)]
|
||||
|
||||
|
||||
-class TestBasicOps(unittest.TestCase):
|
||||
+class TestBasicOps(__TestCase):
|
||||
|
||||
|
||||
def pickletest(self, protocol, it, stop=4, take=1, compare=None):
|
||||
"""Test that an iterator is the same after pickling, also when part-consumed"""
|
||||
@@ -454,14 +484,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
|
||||
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
|
||||
|
||||
|
||||
- @pickle_deprecated
|
||||
def test_permutations(self):
|
||||
- self.assertRaises(TypeError, permutations) # too few arguments
|
||||
|
|
@ -79,11 +79,11 @@ index 7d5ba727389..ff514815da2 100644
|
|||
- self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
|
||||
self.assertEqual(list(permutations(range(3), 2)),
|
||||
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
|
||||
|
||||
|
||||
@@ -498,7 +522,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
if len(set(indices)) == r:
|
||||
yield tuple(pool[i] for i in indices)
|
||||
|
||||
|
||||
- for n in range(7):
|
||||
+ for n in range(5):
|
||||
values = [5*x-12 for x in range(n)]
|
||||
|
|
@ -92,7 +92,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
@@ -515,9 +539,6 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(result, list(permutations(values, None))) # test r as None
|
||||
self.assertEqual(result, list(permutations(values))) # test default r
|
||||
|
||||
|
||||
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
- self.pickletest(proto, permutations(values, r)) # test pickling
|
||||
-
|
||||
|
|
@ -107,7 +107,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
+ # self.assertRaises(TypeError, cycle)
|
||||
self.assertRaises(TypeError, cycle, 5)
|
||||
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
|
||||
|
||||
|
||||
@@ -888,7 +909,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
# Check normal pickled
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
|
|
@ -118,7 +118,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
self.assertEqual(k, elem[0])
|
||||
dup.append(elem)
|
||||
@@ -896,8 +917,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
||||
|
||||
# Check nested case
|
||||
dup = []
|
||||
- for k, g in groupby(s, testR):
|
||||
|
|
@ -140,8 +140,8 @@ index 7d5ba727389..ff514815da2 100644
|
|||
self.assertEqual(k, elem[0])
|
||||
self.assertEqual(ik, elem[2])
|
||||
@@ -917,7 +938,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
# Check case where inner iterator is not used
|
||||
- keys = [k for k, g in groupby(s, testR)]
|
||||
+ keys = [k for k, g in groupby(s, _testR)]
|
||||
|
|
@ -159,7 +159,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
_, g3 = next(it)
|
||||
@@ -936,7 +957,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(list(g3), [])
|
||||
|
||||
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
- it = groupby(s, testR)
|
||||
+ it = groupby(s, _testR)
|
||||
|
|
@ -182,7 +182,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
+ # self.assertRaises(TypeError, filter, isEven, 3)
|
||||
+ # dynamo raises Unsupported in this case
|
||||
+ # self.assertRaises(TypeError, next, filter(range(6), range(6)))
|
||||
|
||||
|
||||
# check copy, deepcopy, pickle
|
||||
- ans = [0,2,4]
|
||||
-
|
||||
|
|
@ -212,7 +212,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
+ # for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
+ # c = filter(isEven, range(6))
|
||||
+ # self.pickletest(proto, c)
|
||||
|
||||
|
||||
- @pickle_deprecated
|
||||
def test_filterfalse(self):
|
||||
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
|
||||
|
|
@ -224,11 +224,11 @@ index 7d5ba727389..ff514815da2 100644
|
|||
- self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
||||
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
- self.pickletest(proto, filterfalse(isEven, range(6)))
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
||||
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
+ self.pickletest(proto, filterfalse(isEven, range(6)))
|
||||
|
||||
|
||||
def test_zip(self):
|
||||
# XXX This is rather silly now that builtin zip() calls zip()...
|
||||
@@ -1047,8 +1070,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
|
@ -243,7 +243,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
lzip('abc', 'def'))
|
||||
self.assertEqual([pair for pair in zip('abc', 'def')],
|
||||
@@ -1105,19 +1128,19 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
||||
|
||||
self.assertEqual(list(zip_longest('abc', 'defg', **{})),
|
||||
list(zip(list('abc')+[None], 'defg'))) # empty keyword dict
|
||||
- self.assertRaises(TypeError, zip_longest, 3)
|
||||
|
|
@ -272,7 +272,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
+ # pass
|
||||
+ # else:
|
||||
+ # self.fail('Did not raise Type in: ' + stmt)
|
||||
|
||||
|
||||
self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')],
|
||||
list(zip('abc', 'def')))
|
||||
@@ -1296,7 +1319,6 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
|
@ -280,7 +280,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
list(product(*args, **dict(repeat=r))))
|
||||
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
|
||||
- self.assertRaises(TypeError, product, range(6), None)
|
||||
|
||||
|
||||
def product1(*args, **kwds):
|
||||
pools = list(map(tuple, args)) * kwds.get('repeat', 1)
|
||||
@@ -1336,7 +1358,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
|
|
@ -295,7 +295,7 @@ index 7d5ba727389..ff514815da2 100644
|
|||
self.assertEqual(list(product(*args)), list(product1(*args)))
|
||||
@@ -1767,6 +1790,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
script_helper.assert_python_ok("-c", script)
|
||||
|
||||
|
||||
# Issue 13454: Crash when deleting backward iterator from tee()
|
||||
+ @skipIfTorchDynamo("infinite loop in torch dynamo")
|
||||
def test_tee_del_backward(self):
|
||||
|
|
@ -303,68 +303,68 @@ index 7d5ba727389..ff514815da2 100644
|
|||
try:
|
||||
@@ -1920,7 +1944,7 @@ class TestBasicOps(unittest.TestCase):
|
||||
tp.foobar = 1
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestExamples(unittest.TestCase):
|
||||
+class TestExamples(__TestCase):
|
||||
|
||||
|
||||
def test_accumulate(self):
|
||||
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
|
||||
@@ -2032,7 +2056,7 @@ class TestExamples(unittest.TestCase):
|
||||
self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
+class TestPurePythonRoughEquivalents(__TestCase):
|
||||
|
||||
|
||||
def test_batched_recipe(self):
|
||||
def batched_recipe(iterable, n):
|
||||
@@ -2081,6 +2105,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
for i, element in zip(range(i + 1, stop), iterable):
|
||||
pass
|
||||
|
||||
|
||||
+ @skipIfTorchDynamo("infinite loop in torch dynamo")
|
||||
def test_islice_recipe(self):
|
||||
self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB'))
|
||||
self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD'))
|
||||
@@ -2265,7 +2290,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
raise
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestGC(unittest.TestCase):
|
||||
+class TestGC(__TestCase):
|
||||
|
||||
|
||||
def makecycle(self, iterator, container):
|
||||
container.append(iterator)
|
||||
@@ -2465,7 +2490,7 @@ def L(seqn):
|
||||
return chain(map(lambda x:x, R(Ig(G(seqn)))))
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestVariousIteratorArgs(unittest.TestCase):
|
||||
+class TestVariousIteratorArgs(__TestCase):
|
||||
|
||||
|
||||
def test_accumulate(self):
|
||||
s = [1,2,3,4,5]
|
||||
@@ -2644,7 +2669,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
|
||||
self.assertRaises(TypeError, tee, N(s))
|
||||
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
|
||||
|
||||
|
||||
-class LengthTransparency(unittest.TestCase):
|
||||
+class LengthTransparency(__TestCase):
|
||||
|
||||
|
||||
def test_repeat(self):
|
||||
self.assertEqual(operator.length_hint(repeat(None, 50)), 50)
|
||||
@@ -2657,7 +2682,7 @@ class LengthTransparency(unittest.TestCase):
|
||||
self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0)
|
||||
self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0)
|
||||
|
||||
|
||||
-class RegressionTests(unittest.TestCase):
|
||||
+class RegressionTests(__TestCase):
|
||||
|
||||
|
||||
def test_sf_793826(self):
|
||||
# Fix Armin Rigo's successful efforts to wreak havoc
|
||||
@@ -2718,6 +2743,7 @@ class RegressionTests(unittest.TestCase):
|
||||
|
||||
|
||||
@support.skip_if_pgo_task
|
||||
@support.requires_resource('cpu')
|
||||
+ @slowTest
|
||||
|
|
@ -373,8 +373,8 @@ index 7d5ba727389..ff514815da2 100644
|
|||
# dealing with long chains of empty iterables. Even with a high
|
||||
@@ -2750,7 +2776,7 @@ class RegressionTests(unittest.TestCase):
|
||||
next(g, None) # shouldn't crash
|
||||
|
||||
|
||||
|
||||
|
||||
-class SubclassWithKwargsTest(unittest.TestCase):
|
||||
+class SubclassWithKwargsTest(__TestCase):
|
||||
def test_keywords_in_subclass(self):
|
||||
|
|
@ -382,8 +382,8 @@ index 7d5ba727389..ff514815da2 100644
|
|||
testcases = [
|
||||
@@ -2805,49 +2831,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
|
||||
|
||||
|
||||
-@support.cpython_only
|
||||
-class SizeofTest(unittest.TestCase):
|
||||
- def setUp(self):
|
||||
|
|
|
|||
|
|
@ -1056,7 +1056,7 @@ class TestBasicOps(__TestCase):
|
|||
self.assertRaises(TypeError, filterfalse, lambda x:x)
|
||||
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
|
||||
self.assertRaises(TypeError, filterfalse, isEven, 3)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
self.pickletest(proto, filterfalse(isEven, range(6)))
|
||||
|
|
@ -1358,7 +1358,7 @@ class TestBasicOps(__TestCase):
|
|||
argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
|
||||
set('abcdefg'), range(11), tuple(range(13))]
|
||||
for i in range(100):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
args = [choice(argtypes) for j in range(randrange(5))]
|
||||
expected_len = prod(map(len, args))
|
||||
self.assertEqual(len(list(product(*args))), expected_len)
|
||||
|
|
|
|||
|
|
@ -67,17 +67,17 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
@@ -36,7 +90,7 @@ class ListTest(list_tests.CommonTest):
|
||||
# earlier due to a newlib bug. See the following mailing list
|
||||
# thread for the details:
|
||||
|
||||
|
||||
self.assertRaises(MemoryError, list, range(sys.maxsize // 2))
|
||||
|
||||
|
||||
# This code used to segfault in Py2.4a3
|
||||
@@ -49,28 +103,31 @@ class ListTest(list_tests.CommonTest):
|
||||
list(sequence=[])
|
||||
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
- class subclass(list):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass(list):
|
||||
+ pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -85,12 +85,12 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
self.assertEqual(list(u), [1, 2])
|
||||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
|
||||
- class subclass_with_init(list):
|
||||
- def __init__(self, seq, newarg=None):
|
||||
- super().__init__(seq)
|
||||
- self.newarg = newarg
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_init(list):
|
||||
+ def __init__(self, seq, newarg=None):
|
||||
+ super().__init__(seq)
|
||||
|
|
@ -99,13 +99,13 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
self.assertIs(type(u), subclass_with_init)
|
||||
self.assertEqual(list(u), [1, 2])
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
|
||||
- class subclass_with_new(list):
|
||||
- def __new__(cls, seq, newarg=None):
|
||||
- self = super().__new__(cls, seq)
|
||||
- self.newarg = newarg
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_new(list):
|
||||
+ def __new__(cls, seq, newarg=None):
|
||||
+ self = super().__new__(cls, seq)
|
||||
|
|
@ -116,7 +116,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
self.assertEqual(list(u), [1, 2])
|
||||
@@ -117,14 +174,15 @@ class ListTest(list_tests.CommonTest):
|
||||
lst *= size
|
||||
|
||||
|
||||
def test_repr_mutate(self):
|
||||
- class Obj:
|
||||
- @staticmethod
|
||||
|
|
@ -126,7 +126,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
- except IndexError:
|
||||
- pass
|
||||
- return 'obj'
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Obj:
|
||||
+ @staticmethod
|
||||
+ def __repr__():
|
||||
|
|
@ -135,7 +135,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
+ except IndexError:
|
||||
+ pass
|
||||
+ return 'obj'
|
||||
|
||||
|
||||
mylist = [Obj() for _ in range(5)]
|
||||
self.assertEqual(repr(mylist), '[obj, obj, obj]')
|
||||
@@ -220,26 +278,28 @@ class ListTest(list_tests.CommonTest):
|
||||
|
|
@ -143,11 +143,11 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
# optimization causes failures in code that relies on distinct
|
||||
# function addresses.
|
||||
- class L(list): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class L(list): pass
|
||||
with self.assertRaises(TypeError):
|
||||
(3,) + L([1,2])
|
||||
|
||||
|
||||
def test_equal_operator_modifying_operand(self):
|
||||
# test fix for seg fault reported in bpo-38588 part 2.
|
||||
- class X:
|
||||
|
|
@ -164,7 +164,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
- def __eq__(self, other):
|
||||
- list3.clear()
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class X:
|
||||
+ def __eq__(self,other) :
|
||||
+ list2.clear()
|
||||
|
|
@ -179,29 +179,29 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
+ def __eq__(self, other):
|
||||
+ list3.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
|
||||
list1 = [X()]
|
||||
list2 = [Y()]
|
||||
@@ -250,24 +310,26 @@ class ListTest(list_tests.CommonTest):
|
||||
self.assertFalse(list3 == list4)
|
||||
|
||||
|
||||
def test_lt_operator_modifying_operand(self):
|
||||
- # See gh-120298
|
||||
- class evil:
|
||||
- def __lt__(self, other):
|
||||
- other.clear()
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # See gh-120298
|
||||
+ class evil:
|
||||
+ def __lt__(self, other):
|
||||
+ other.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
|
||||
a = [[evil()]]
|
||||
with self.assertRaises(TypeError):
|
||||
a[0] < a
|
||||
|
||||
|
||||
def test_list_index_modifing_operand(self):
|
||||
- # See gh-120384
|
||||
- class evil:
|
||||
|
|
@ -210,7 +210,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
- def __iter__(self):
|
||||
- yield from self.lst
|
||||
- self.lst.clear()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ # See gh-120384
|
||||
+ class evil:
|
||||
+ def __init__(self, lst):
|
||||
|
|
@ -218,7 +218,7 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
+ def __iter__(self):
|
||||
+ yield from self.lst
|
||||
+ self.lst.clear()
|
||||
|
||||
|
||||
lst = list(range(5))
|
||||
operand = evil(lst)
|
||||
@@ -286,19 +348,21 @@ class ListTest(list_tests.CommonTest):
|
||||
|
|
@ -229,39 +229,39 @@ index 23ef902aa0b..b9afb1ef26e 100644
|
|||
- def __eq__(self, other):
|
||||
- lst.clear()
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class X:
|
||||
+ def __eq__(self, other):
|
||||
+ lst.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
|
||||
lst = [X()]
|
||||
with self.assertRaises(ValueError):
|
||||
lst.index(lst)
|
||||
|
||||
|
||||
- class L(list):
|
||||
- def __eq__(self, other):
|
||||
- str(other)
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class L(list):
|
||||
+ def __eq__(self, other):
|
||||
+ str(other)
|
||||
+ return NotImplemented
|
||||
|
||||
|
||||
lst = L([X()])
|
||||
lst.count(lst)
|
||||
@@ -324,6 +388,7 @@ class ListTest(list_tests.CommonTest):
|
||||
a.append(4)
|
||||
self.assertEqual(list(it), [])
|
||||
|
||||
|
||||
+ @unittest.skip("Fails on python <=3.13.2 and passes on >=3.13.3")
|
||||
def test_deopt_from_append_list(self):
|
||||
# gh-132011: it used to crash, because
|
||||
# of `CALL_LIST_APPEND` specialization failure.
|
||||
@@ -345,4 +410,4 @@ class ListTest(list_tests.CommonTest):
|
||||
self.assertEqual(rc, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class ListTest(list_tests.CommonTest):
|
|||
list(sequence=[])
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass(list):
|
||||
pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -110,7 +110,7 @@ class ListTest(list_tests.CommonTest):
|
|||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_init(list):
|
||||
def __init__(self, seq, newarg=None):
|
||||
super().__init__(seq)
|
||||
|
|
@ -120,7 +120,7 @@ class ListTest(list_tests.CommonTest):
|
|||
self.assertEqual(list(u), [1, 2])
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_new(list):
|
||||
def __new__(cls, seq, newarg=None):
|
||||
self = super().__new__(cls, seq)
|
||||
|
|
@ -172,7 +172,7 @@ class ListTest(list_tests.CommonTest):
|
|||
lst *= size
|
||||
|
||||
def test_repr_mutate(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Obj:
|
||||
@staticmethod
|
||||
def __repr__():
|
||||
|
|
@ -276,14 +276,14 @@ class ListTest(list_tests.CommonTest):
|
|||
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
|
||||
# optimization causes failures in code that relies on distinct
|
||||
# function addresses.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class L(list): pass
|
||||
with self.assertRaises(TypeError):
|
||||
(3,) + L([1,2])
|
||||
|
||||
def test_equal_operator_modifying_operand(self):
|
||||
# test fix for seg fault reported in bpo-38588 part 2.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __eq__(self,other) :
|
||||
list2.clear()
|
||||
|
|
@ -308,7 +308,7 @@ class ListTest(list_tests.CommonTest):
|
|||
self.assertFalse(list3 == list4)
|
||||
|
||||
def test_lt_operator_modifying_operand(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# See gh-120298
|
||||
class evil:
|
||||
def __lt__(self, other):
|
||||
|
|
@ -320,7 +320,7 @@ class ListTest(list_tests.CommonTest):
|
|||
a[0] < a
|
||||
|
||||
def test_list_index_modifing_operand(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
# See gh-120384
|
||||
class evil:
|
||||
def __init__(self, lst):
|
||||
|
|
@ -346,7 +346,7 @@ class ListTest(list_tests.CommonTest):
|
|||
# bpo-38610: The count(), index(), and remove() methods were not
|
||||
# holding strong references to list elements while calling
|
||||
# PyObject_RichCompareBool().
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __eq__(self, other):
|
||||
lst.clear()
|
||||
|
|
@ -356,7 +356,7 @@ class ListTest(list_tests.CommonTest):
|
|||
with self.assertRaises(ValueError):
|
||||
lst.index(lst)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class L(list):
|
||||
def __eq__(self, other):
|
||||
str(other)
|
||||
|
|
|
|||
|
|
@ -63,20 +63,20 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
+
|
||||
# Python test set -- math module
|
||||
# XXXX Should not do tests around zero only
|
||||
|
||||
|
||||
@@ -242,7 +300,7 @@ class BadDescr:
|
||||
def __get__(self, obj, objtype=None):
|
||||
raise ValueError
|
||||
|
||||
|
||||
-class MathTests(unittest.TestCase):
|
||||
+class MathTests(__TestCase):
|
||||
|
||||
|
||||
def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0):
|
||||
"""Compare arguments expected and got, as floats, if either
|
||||
@@ -417,16 +475,17 @@ class MathTests(unittest.TestCase):
|
||||
#self.assertEqual(math.ceil(NINF), NINF)
|
||||
#self.assertTrue(math.isnan(math.ceil(NAN)))
|
||||
|
||||
|
||||
- class TestCeil:
|
||||
- def __ceil__(self):
|
||||
- return 42
|
||||
|
|
@ -87,7 +87,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
- pass
|
||||
- class TestBadCeil:
|
||||
- __ceil__ = BadDescr()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class TestCeil:
|
||||
+ def __ceil__(self):
|
||||
+ return 42
|
||||
|
|
@ -104,7 +104,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -533,6 +592,7 @@ class MathTests(unittest.TestCase):
|
||||
self.ftest('fabs(0)', math.fabs(0), 0)
|
||||
self.ftest('fabs(1)', math.fabs(1), 1)
|
||||
|
||||
|
||||
+ @skipIfTorchDynamo("infinite loop")
|
||||
def testFactorial(self):
|
||||
self.assertEqual(math.factorial(0), 1)
|
||||
|
|
@ -112,7 +112,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -573,16 +633,17 @@ class MathTests(unittest.TestCase):
|
||||
#self.assertEqual(math.ceil(NINF), NINF)
|
||||
#self.assertTrue(math.isnan(math.floor(NAN)))
|
||||
|
||||
|
||||
- class TestFloor:
|
||||
- def __floor__(self):
|
||||
- return 42
|
||||
|
|
@ -123,7 +123,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
- pass
|
||||
- class TestBadFloor:
|
||||
- __floor__ = BadDescr()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class TestFloor:
|
||||
+ def __floor__(self):
|
||||
+ return 42
|
||||
|
|
@ -139,32 +139,32 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
self.assertEqual(math.floor(FloatLike(41.9)), 41)
|
||||
@@ -995,8 +1056,9 @@ class MathTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
# Verify tuple subclasses are allowed
|
||||
- class T(tuple):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class T(tuple):
|
||||
+ pass
|
||||
self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0)
|
||||
|
||||
|
||||
# Test handling of bad arguments
|
||||
@@ -1028,8 +1090,9 @@ class MathTests(unittest.TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
dist([1], 2)
|
||||
|
||||
|
||||
- class BadFloat:
|
||||
- __float__ = BadDescr()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadFloat:
|
||||
+ __float__ = BadDescr()
|
||||
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
dist([1], [BadFloat()])
|
||||
@@ -1072,6 +1135,7 @@ class MathTests(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
math.dist([1, 2], [3, 4, 5])
|
||||
|
||||
|
||||
+ @slowTest
|
||||
def testIsqrt(self):
|
||||
# Test a variety of inputs, large and small.
|
||||
|
|
@ -172,26 +172,26 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -1101,12 +1165,13 @@ class MathTests(unittest.TestCase):
|
||||
self.assertIs(type(s), int)
|
||||
self.assertEqual(s, 0)
|
||||
|
||||
|
||||
- class IntegerLike(object):
|
||||
- def __init__(self, value):
|
||||
- self.value = value
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class IntegerLike(object):
|
||||
+ def __init__(self, value):
|
||||
+ self.value = value
|
||||
|
||||
|
||||
- def __index__(self):
|
||||
- return self.value
|
||||
+ def __index__(self):
|
||||
+ return self.value
|
||||
|
||||
|
||||
s = math.isqrt(IntegerLike(1729))
|
||||
self.assertIs(type(s), int)
|
||||
@@ -1202,12 +1267,6 @@ class MathTests(unittest.TestCase):
|
||||
self.assertEqual(math.ldexp(NINF, n), NINF)
|
||||
self.assertTrue(math.isnan(math.ldexp(NAN, n)))
|
||||
|
||||
|
||||
- @requires_IEEE_754
|
||||
- def testLdexp_denormal(self):
|
||||
- # Denormal output incorrectly rounded (truncated)
|
||||
|
|
@ -204,7 +204,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -1233,6 +1292,7 @@ class MathTests(unittest.TestCase):
|
||||
self.assertRaises(ValueError, math.log1p, -1)
|
||||
self.assertEqual(math.log1p(INF), INF)
|
||||
|
||||
|
||||
+ @skipIfTorchDynamo("Infinite loop")
|
||||
@requires_IEEE_754
|
||||
def testLog2(self):
|
||||
|
|
@ -212,7 +212,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -1251,6 +1311,7 @@ class MathTests(unittest.TestCase):
|
||||
self.assertRaises(ValueError, math.log2, NINF)
|
||||
self.assertTrue(math.isnan(math.log2(NAN)))
|
||||
|
||||
|
||||
+ @skipIfTorchDynamo("Infinite loop")
|
||||
@requires_IEEE_754
|
||||
# log2() is not accurate enough on Mac OS X Tiger (10.4)
|
||||
|
|
@ -220,20 +220,20 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -1332,17 +1393,18 @@ class MathTests(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
sumprod(raise_after(5), range(10))
|
||||
|
||||
|
||||
- from test.test_iter import BasicIterClass
|
||||
+ from test_iter import BasicIterClass
|
||||
|
||||
|
||||
self.assertEqual(sumprod(BasicIterClass(1), [1]), 0)
|
||||
self.assertEqual(sumprod([1], BasicIterClass(1)), 0)
|
||||
|
||||
|
||||
# Error in multiplication
|
||||
- class BadMultiply:
|
||||
- def __mul__(self, other):
|
||||
- raise RuntimeError
|
||||
- def __rmul__(self, other):
|
||||
- raise RuntimeError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadMultiply:
|
||||
+ def __mul__(self, other):
|
||||
+ raise RuntimeError
|
||||
|
|
@ -245,7 +245,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -1387,25 +1449,26 @@ class MathTests(unittest.TestCase):
|
||||
Decimal = decimal.Decimal
|
||||
Fraction = fractions.Fraction
|
||||
|
||||
|
||||
- class Int(int):
|
||||
- def __add__(self, other):
|
||||
- return Int(int(self) + int(other))
|
||||
|
|
@ -265,7 +265,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
- __rmul__ = __mul__
|
||||
- def __repr__(self):
|
||||
- return f'Flt({int(self)})'
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Int(int):
|
||||
+ def __add__(self, other):
|
||||
+ return Int(int(self) + int(other))
|
||||
|
|
@ -285,13 +285,13 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
+ __rmul__ = __mul__
|
||||
+ def __repr__(self):
|
||||
+ return f'Flt({int(self)})'
|
||||
|
||||
|
||||
def baseline_sumprod(p, q):
|
||||
"""This defines the target behavior including exceptions and special values.
|
||||
@@ -1925,16 +1988,17 @@ class MathTests(unittest.TestCase):
|
||||
self.assertEqual(math.trunc(-0.999999), -0)
|
||||
self.assertEqual(math.trunc(-100.999), -100)
|
||||
|
||||
|
||||
- class TestTrunc:
|
||||
- def __trunc__(self):
|
||||
- return 23
|
||||
|
|
@ -302,7 +302,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
- pass
|
||||
- class TestBadTrunc:
|
||||
- __trunc__ = BadDescr()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class TestTrunc:
|
||||
+ def __trunc__(self):
|
||||
+ return 23
|
||||
|
|
@ -313,27 +313,27 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
+ pass
|
||||
+ class TestBadTrunc:
|
||||
+ __trunc__ = BadDescr()
|
||||
|
||||
|
||||
self.assertEqual(math.trunc(TestTrunc()), 23)
|
||||
self.assertEqual(math.trunc(FloatTrunc()), 23)
|
||||
@@ -2167,9 +2231,10 @@ class MathTests(unittest.TestCase):
|
||||
self.assertEqual(prod([1., F(3, 2)]), 1.5)
|
||||
|
||||
|
||||
# Error in multiplication
|
||||
- class BadMultiply:
|
||||
- def __rmul__(self, other):
|
||||
- raise RuntimeError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class BadMultiply:
|
||||
+ def __rmul__(self, other):
|
||||
+ raise RuntimeError
|
||||
with self.assertRaises(RuntimeError):
|
||||
prod([10., BadMultiply()])
|
||||
|
||||
|
||||
@@ -2252,6 +2317,7 @@ class MathTests(unittest.TestCase):
|
||||
self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
|
||||
decimal.Decimal)
|
||||
|
||||
|
||||
+ @skipIfTorchDynamo("Infinite loop")
|
||||
def testPerm(self):
|
||||
perm = math.perm
|
||||
|
|
@ -341,15 +341,15 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
@@ -2316,6 +2382,7 @@ class MathTests(unittest.TestCase):
|
||||
self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int)
|
||||
self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int)
|
||||
|
||||
|
||||
+ @skipIfTorchDynamo("infinite loop")
|
||||
def testComb(self):
|
||||
comb = math.comb
|
||||
factorial = math.factorial
|
||||
@@ -2446,6 +2513,7 @@ class MathTests(unittest.TestCase):
|
||||
math.nextafter(1.0, INF, steps=-1)
|
||||
|
||||
|
||||
|
||||
|
||||
+ @unittest.skip("flaky test under torch dynamo") # works on pytest and crashes on unittest
|
||||
@requires_IEEE_754
|
||||
def test_ulp(self):
|
||||
|
|
@ -362,7 +362,7 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
- def __float__(self):
|
||||
- self.converted = True
|
||||
- 1/0
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class F:
|
||||
+ def __float__(self):
|
||||
+ self.converted = True
|
||||
|
|
@ -372,21 +372,21 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
with self.assertRaises(TypeError):
|
||||
@@ -2508,7 +2577,7 @@ class MathTests(unittest.TestCase):
|
||||
self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y))
|
||||
|
||||
|
||||
|
||||
|
||||
-class IsCloseTests(unittest.TestCase):
|
||||
+class IsCloseTests(__TestCase):
|
||||
isclose = math.isclose # subclasses should override this
|
||||
|
||||
|
||||
def assertIsClose(self, a, b, *args, **kwargs):
|
||||
@@ -2631,7 +2700,7 @@ class IsCloseTests(unittest.TestCase):
|
||||
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
|
||||
|
||||
|
||||
|
||||
|
||||
-class FMATests(unittest.TestCase):
|
||||
+class FMATests(__TestCase):
|
||||
""" Tests for math.fma. """
|
||||
|
||||
|
||||
def test_fma_nan_results(self):
|
||||
@@ -2719,8 +2788,7 @@ class FMATests(unittest.TestCase):
|
||||
# properly: it doesn't use the right sign when the result is zero.
|
||||
|
|
@ -400,8 +400,8 @@ index 5ee3055c871..5402cdc4a6c 100644
|
|||
nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]
|
||||
@@ -2879,10 +2947,5 @@ class FMATests(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
-def load_tests(loader, tests, pattern):
|
||||
- from doctest import DocFileSuite
|
||||
- tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt")))
|
||||
|
|
|
|||
|
|
@ -475,7 +475,7 @@ class MathTests(__TestCase):
|
|||
#self.assertEqual(math.ceil(NINF), NINF)
|
||||
#self.assertTrue(math.isnan(math.ceil(NAN)))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class TestCeil:
|
||||
def __ceil__(self):
|
||||
return 42
|
||||
|
|
@ -633,7 +633,7 @@ class MathTests(__TestCase):
|
|||
#self.assertEqual(math.ceil(NINF), NINF)
|
||||
#self.assertTrue(math.isnan(math.floor(NAN)))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class TestFloor:
|
||||
def __floor__(self):
|
||||
return 42
|
||||
|
|
@ -1056,7 +1056,7 @@ class MathTests(__TestCase):
|
|||
)
|
||||
|
||||
# Verify tuple subclasses are allowed
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class T(tuple):
|
||||
pass
|
||||
self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0)
|
||||
|
|
@ -1090,7 +1090,7 @@ class MathTests(__TestCase):
|
|||
with self.assertRaises(TypeError):
|
||||
dist([1], 2)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadFloat:
|
||||
__float__ = BadDescr()
|
||||
|
||||
|
|
@ -1165,7 +1165,7 @@ class MathTests(__TestCase):
|
|||
self.assertIs(type(s), int)
|
||||
self.assertEqual(s, 0)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class IntegerLike(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
|
@ -1399,7 +1399,7 @@ class MathTests(__TestCase):
|
|||
self.assertEqual(sumprod([1], BasicIterClass(1)), 0)
|
||||
|
||||
# Error in multiplication
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadMultiply:
|
||||
def __mul__(self, other):
|
||||
raise RuntimeError
|
||||
|
|
@ -1449,7 +1449,7 @@ class MathTests(__TestCase):
|
|||
Decimal = decimal.Decimal
|
||||
Fraction = fractions.Fraction
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Int(int):
|
||||
def __add__(self, other):
|
||||
return Int(int(self) + int(other))
|
||||
|
|
@ -1988,7 +1988,7 @@ class MathTests(__TestCase):
|
|||
self.assertEqual(math.trunc(-0.999999), -0)
|
||||
self.assertEqual(math.trunc(-100.999), -100)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class TestTrunc:
|
||||
def __trunc__(self):
|
||||
return 23
|
||||
|
|
@ -2231,7 +2231,7 @@ class MathTests(__TestCase):
|
|||
self.assertEqual(prod([1., F(3, 2)]), 1.5)
|
||||
|
||||
# Error in multiplication
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class BadMultiply:
|
||||
def __rmul__(self, other):
|
||||
raise RuntimeError
|
||||
|
|
@ -2540,7 +2540,7 @@ class MathTests(__TestCase):
|
|||
def test_issue39871(self):
|
||||
# A SystemError should not be raised if the first arg to atan2(),
|
||||
# copysign(), or remainder() cannot be converted to a float.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class F:
|
||||
def __float__(self):
|
||||
self.converted = True
|
||||
|
|
|
|||
|
|
@ -27,13 +27,13 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
import inspect
|
||||
import pickle
|
||||
@@ -84,9 +104,10 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_eq(self):
|
||||
operator = self.module
|
||||
- class C(object):
|
||||
- def __eq__(self, other):
|
||||
- raise SyntaxError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ def __eq__(self, other):
|
||||
+ raise SyntaxError
|
||||
|
|
@ -41,13 +41,13 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
self.assertRaises(SyntaxError, operator.eq, C(), C())
|
||||
self.assertFalse(operator.eq(1, 0))
|
||||
@@ -98,9 +119,10 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_ne(self):
|
||||
operator = self.module
|
||||
- class C(object):
|
||||
- def __ne__(self, other):
|
||||
- raise SyntaxError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ def __ne__(self, other):
|
||||
+ raise SyntaxError
|
||||
|
|
@ -61,21 +61,21 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
- class M:
|
||||
- def __matmul__(self, other):
|
||||
- return other - 1
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class M:
|
||||
+ def __matmul__(self, other):
|
||||
+ return other - 1
|
||||
self.assertEqual(M() @ 42, 41)
|
||||
|
||||
|
||||
def test_neg(self):
|
||||
@@ -315,9 +338,10 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_truth(self):
|
||||
operator = self.module
|
||||
- class C(object):
|
||||
- def __bool__(self):
|
||||
- raise SyntaxError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ def __bool__(self):
|
||||
+ raise SyntaxError
|
||||
|
|
@ -83,12 +83,12 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
self.assertRaises(SyntaxError, operator.truth, C())
|
||||
self.assertTrue(operator.truth(5))
|
||||
@@ -349,8 +373,9 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_attrgetter(self):
|
||||
operator = self.module
|
||||
- class A:
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ pass
|
||||
a = A()
|
||||
|
|
@ -97,39 +97,39 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
@@ -371,9 +396,10 @@ class OperatorTestCase:
|
||||
self.assertEqual(operator.attrgetter('x','z','y')(record), ('X', 'Z', 'Y'))
|
||||
self.assertRaises(TypeError, operator.attrgetter, ('x', (), 'y'))
|
||||
|
||||
|
||||
- class C(object):
|
||||
- def __getattr__(self, name):
|
||||
- raise SyntaxError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ def __getattr__(self, name):
|
||||
+ raise SyntaxError
|
||||
self.assertRaises(SyntaxError, operator.attrgetter('foo'), C())
|
||||
|
||||
|
||||
# recursive gets
|
||||
@@ -411,9 +437,10 @@ class OperatorTestCase:
|
||||
f = operator.itemgetter(10)
|
||||
self.assertRaises(IndexError, f, a)
|
||||
|
||||
|
||||
- class C(object):
|
||||
- def __getitem__(self, name):
|
||||
- raise SyntaxError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ def __getitem__(self, name):
|
||||
+ raise SyntaxError
|
||||
self.assertRaises(SyntaxError, operator.itemgetter(42), C())
|
||||
|
||||
|
||||
f = operator.itemgetter('name')
|
||||
@@ -444,9 +471,10 @@ class OperatorTestCase:
|
||||
self.assertEqual(operator.itemgetter(slice(2, 4))(t), ('c', 'd'))
|
||||
|
||||
|
||||
# interesting sequences
|
||||
- class T(tuple):
|
||||
- 'Tuple subclass'
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class T(tuple):
|
||||
+ 'Tuple subclass'
|
||||
+ pass
|
||||
|
|
@ -147,7 +147,7 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
- return f
|
||||
- def baz(*args, **kwds):
|
||||
- return kwds['name'], kwds['self']
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ def foo(self, *args, **kwds):
|
||||
+ return args[0] + args[1]
|
||||
|
|
@ -159,7 +159,7 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
f = operator.methodcaller('foo')
|
||||
self.assertRaises(IndexError, f, a)
|
||||
@@ -480,21 +509,22 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_inplace(self):
|
||||
operator = self.module
|
||||
- class C(object):
|
||||
|
|
@ -177,7 +177,7 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
- def __itruediv__ (self, other): return "itruediv"
|
||||
- def __ixor__ (self, other): return "ixor"
|
||||
- def __getitem__(self, other): return 5 # so that C is a sequence
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ def __iadd__ (self, other): return "iadd"
|
||||
+ def __iand__ (self, other): return "iand"
|
||||
|
|
@ -197,27 +197,27 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
self.assertEqual(operator.iadd (c, 5), "iadd")
|
||||
self.assertEqual(operator.iand (c, 5), "iand")
|
||||
@@ -520,9 +550,10 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_index(self):
|
||||
operator = self.module
|
||||
- class X:
|
||||
- def __index__(self):
|
||||
- return 1
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class X:
|
||||
+ def __index__(self):
|
||||
+ return 1
|
||||
|
||||
|
||||
self.assertEqual(operator.index(X()), 1)
|
||||
self.assertEqual(operator.index(0), 0)
|
||||
@@ -539,9 +570,10 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_not_(self):
|
||||
operator = self.module
|
||||
- class C:
|
||||
- def __bool__(self):
|
||||
- raise SyntaxError
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ def __bool__(self):
|
||||
+ raise SyntaxError
|
||||
|
|
@ -225,17 +225,17 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
self.assertRaises(SyntaxError, operator.not_, C())
|
||||
self.assertFalse(operator.not_(5))
|
||||
@@ -551,15 +583,16 @@ class OperatorTestCase:
|
||||
|
||||
|
||||
def test_length_hint(self):
|
||||
operator = self.module
|
||||
- class X(object):
|
||||
- def __init__(self, value):
|
||||
- self.value = value
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class X(object):
|
||||
+ def __init__(self, value):
|
||||
+ self.value = value
|
||||
|
||||
|
||||
- def __length_hint__(self):
|
||||
- if type(self.value) is type:
|
||||
- raise self.value
|
||||
|
|
@ -246,47 +246,47 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
+ raise self.value
|
||||
+ else:
|
||||
+ return self.value
|
||||
|
||||
|
||||
self.assertEqual(operator.length_hint([], 2), 0)
|
||||
self.assertEqual(operator.length_hint(iter([1, 2, 3])), 3)
|
||||
@@ -574,7 +607,8 @@ class OperatorTestCase:
|
||||
with self.assertRaises(LookupError):
|
||||
operator.length_hint(X(LookupError))
|
||||
|
||||
|
||||
- class Y: pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Y: pass
|
||||
|
||||
|
||||
msg = "'str' object cannot be interpreted as an integer"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
@@ -628,11 +662,11 @@ class OperatorTestCase:
|
||||
self.assertEqual(str(sig), '(obj, /)')
|
||||
|
||||
|
||||
|
||||
|
||||
-class PyOperatorTestCase(OperatorTestCase, unittest.TestCase):
|
||||
+class PyOperatorTestCase(OperatorTestCase, __TestCase):
|
||||
module = py_operator
|
||||
|
||||
|
||||
@unittest.skipUnless(c_operator, 'requires _operator')
|
||||
-class COperatorTestCase(OperatorTestCase, unittest.TestCase):
|
||||
+class COperatorTestCase(OperatorTestCase, __TestCase):
|
||||
module = c_operator
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -645,8 +679,9 @@ class OperatorPickleTestCase:
|
||||
|
||||
|
||||
def test_attrgetter(self):
|
||||
attrgetter = self.module.attrgetter
|
||||
- class A:
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ pass
|
||||
a = A()
|
||||
a.x = 'X'
|
||||
a.y = 'Y'
|
||||
@@ -688,13 +723,14 @@ class OperatorPickleTestCase:
|
||||
|
||||
|
||||
def test_methodcaller(self):
|
||||
methodcaller = self.module.methodcaller
|
||||
- class A:
|
||||
|
|
@ -296,7 +296,7 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
- return f
|
||||
- def baz(*args, **kwds):
|
||||
- return kwds['name'], kwds['self']
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ def foo(self, *args, **kwds):
|
||||
+ return args[0] + args[1]
|
||||
|
|
@ -310,31 +310,31 @@ index d90f820052c..5d9fdfb70a4 100644
|
|||
@@ -717,25 +753,25 @@ class OperatorPickleTestCase:
|
||||
# Can't test repr consistently with multiple keyword args
|
||||
self.assertEqual(f2(a), f(a))
|
||||
|
||||
|
||||
-class PyPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
|
||||
+class PyPyOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
|
||||
module = py_operator
|
||||
module2 = py_operator
|
||||
|
||||
|
||||
@unittest.skipUnless(c_operator, 'requires _operator')
|
||||
-class PyCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
|
||||
+class PyCOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
|
||||
module = py_operator
|
||||
module2 = c_operator
|
||||
|
||||
|
||||
@unittest.skipUnless(c_operator, 'requires _operator')
|
||||
-class CPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
|
||||
+class CPyOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
|
||||
module = c_operator
|
||||
module2 = py_operator
|
||||
|
||||
|
||||
@unittest.skipUnless(c_operator, 'requires _operator')
|
||||
-class CCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
|
||||
+class CCOperatorPickleTestCase(OperatorPickleTestCase, __TestCase):
|
||||
module = c_operator
|
||||
module2 = c_operator
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_eq(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
def __eq__(self, other):
|
||||
raise SyntaxError
|
||||
|
|
@ -119,7 +119,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_ne(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
def __ne__(self, other):
|
||||
raise SyntaxError
|
||||
|
|
@ -267,7 +267,7 @@ class OperatorTestCase:
|
|||
operator = self.module
|
||||
self.assertRaises(TypeError, operator.matmul)
|
||||
self.assertRaises(TypeError, operator.matmul, 42, 42)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class M:
|
||||
def __matmul__(self, other):
|
||||
return other - 1
|
||||
|
|
@ -338,7 +338,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_truth(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
def __bool__(self):
|
||||
raise SyntaxError
|
||||
|
|
@ -373,7 +373,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_attrgetter(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
pass
|
||||
a = A()
|
||||
|
|
@ -396,7 +396,7 @@ class OperatorTestCase:
|
|||
self.assertEqual(operator.attrgetter('x','z','y')(record), ('X', 'Z', 'Y'))
|
||||
self.assertRaises(TypeError, operator.attrgetter, ('x', (), 'y'))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
def __getattr__(self, name):
|
||||
raise SyntaxError
|
||||
|
|
@ -437,7 +437,7 @@ class OperatorTestCase:
|
|||
f = operator.itemgetter(10)
|
||||
self.assertRaises(IndexError, f, a)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
def __getitem__(self, name):
|
||||
raise SyntaxError
|
||||
|
|
@ -471,7 +471,7 @@ class OperatorTestCase:
|
|||
self.assertEqual(operator.itemgetter(slice(2, 4))(t), ('c', 'd'))
|
||||
|
||||
# interesting sequences
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class T(tuple):
|
||||
'Tuple subclass'
|
||||
pass
|
||||
|
|
@ -483,7 +483,7 @@ class OperatorTestCase:
|
|||
operator = self.module
|
||||
self.assertRaises(TypeError, operator.methodcaller)
|
||||
self.assertRaises(TypeError, operator.methodcaller, 12)
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
def foo(self, *args, **kwds):
|
||||
return args[0] + args[1]
|
||||
|
|
@ -509,7 +509,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_inplace(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
def __iadd__ (self, other): return "iadd"
|
||||
def __iand__ (self, other): return "iand"
|
||||
|
|
@ -550,7 +550,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_index(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __index__(self):
|
||||
return 1
|
||||
|
|
@ -570,7 +570,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_not_(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def __bool__(self):
|
||||
raise SyntaxError
|
||||
|
|
@ -583,7 +583,7 @@ class OperatorTestCase:
|
|||
|
||||
def test_length_hint(self):
|
||||
operator = self.module
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
|
@ -607,7 +607,7 @@ class OperatorTestCase:
|
|||
with self.assertRaises(LookupError):
|
||||
operator.length_hint(X(LookupError))
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Y: pass
|
||||
|
||||
msg = "'str' object cannot be interpreted as an integer"
|
||||
|
|
@ -679,7 +679,7 @@ class OperatorPickleTestCase:
|
|||
|
||||
def test_attrgetter(self):
|
||||
attrgetter = self.module.attrgetter
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
pass
|
||||
a = A()
|
||||
|
|
@ -723,7 +723,7 @@ class OperatorPickleTestCase:
|
|||
|
||||
def test_methodcaller(self):
|
||||
methodcaller = self.module.methodcaller
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
def foo(self, *args, **kwds):
|
||||
return args[0] + args[1]
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
import contextlib
|
||||
import copy
|
||||
@@ -113,13 +170,14 @@ class OrderedDictTests:
|
||||
|
||||
|
||||
def test_init_calls(self):
|
||||
calls = []
|
||||
- class Spam:
|
||||
|
|
@ -74,7 +74,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
- def items(self):
|
||||
- calls.append('items')
|
||||
- return ()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Spam:
|
||||
+ def keys(self):
|
||||
+ calls.append('keys')
|
||||
|
|
@ -82,7 +82,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
+ def items(self):
|
||||
+ calls.append('items')
|
||||
+ return ()
|
||||
|
||||
|
||||
self.OrderedDict(Spam())
|
||||
self.assertEqual(calls, ['keys'])
|
||||
@@ -129,9 +187,10 @@ class OrderedDictTests:
|
||||
|
|
@ -92,21 +92,21 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
- class ODNI(OrderedDict):
|
||||
- def __init__(*args, **kwargs):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ODNI(OrderedDict):
|
||||
+ def __init__(*args, **kwargs):
|
||||
+ pass
|
||||
od = ODNI()
|
||||
od['a'] = 1 # This used to fail because __init__ was bypassed
|
||||
|
||||
|
||||
@@ -267,9 +326,10 @@ class OrderedDictTests:
|
||||
self.assertEqual(od.pop(k, 12345), 12345)
|
||||
|
||||
|
||||
# make sure pop still works when __missing__ is defined
|
||||
- class Missing(OrderedDict):
|
||||
- def __missing__(self, key):
|
||||
- return 0
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Missing(OrderedDict):
|
||||
+ def __missing__(self, key):
|
||||
+ return 0
|
||||
|
|
@ -115,17 +115,17 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
self.assertEqual(m.pop('a', 6), 1)
|
||||
@@ -416,9 +476,10 @@ class OrderedDictTests:
|
||||
self.assertEqual(od.setdefault('g', default=9), 9)
|
||||
|
||||
|
||||
# make sure setdefault still works when __missing__ is defined
|
||||
- class Missing(OrderedDict):
|
||||
- def __missing__(self, key):
|
||||
- return 0
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Missing(OrderedDict):
|
||||
+ def __missing__(self, key):
|
||||
+ return 0
|
||||
self.assertEqual(Missing().setdefault(5, 9), 9)
|
||||
|
||||
|
||||
def test_reinsert(self):
|
||||
@@ -484,9 +545,10 @@ class OrderedDictTests:
|
||||
def test_override_update(self):
|
||||
|
|
@ -134,13 +134,13 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
- class MyOD(OrderedDict):
|
||||
- def update(self, *args, **kwds):
|
||||
- raise Exception()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyOD(OrderedDict):
|
||||
+ def update(self, *args, **kwds):
|
||||
+ raise Exception()
|
||||
items = [('a', 1), ('c', 3), ('b', 2)]
|
||||
self.assertEqual(list(MyOD(items).items()), items)
|
||||
|
||||
|
||||
@@ -507,9 +569,10 @@ class OrderedDictTests:
|
||||
# should not crash Python.
|
||||
OrderedDict = self.OrderedDict
|
||||
|
|
@ -148,7 +148,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
- class MyOD(OrderedDict):
|
||||
- def __del__(self):
|
||||
- deleted.append(self.i)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyOD(OrderedDict):
|
||||
+ def __del__(self):
|
||||
+ deleted.append(self.i)
|
||||
|
|
@ -158,7 +158,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
@@ -521,19 +584,20 @@ class OrderedDictTests:
|
||||
def test_delitem_hash_collision(self):
|
||||
OrderedDict = self.OrderedDict
|
||||
|
||||
|
||||
- class Key:
|
||||
- def __init__(self, hash):
|
||||
- self._hash = hash
|
||||
|
|
@ -172,7 +172,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
- return False
|
||||
- def __repr__(self):
|
||||
- return self.value
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key:
|
||||
+ def __init__(self, hash):
|
||||
+ self._hash = hash
|
||||
|
|
@ -186,149 +186,149 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
+ return False
|
||||
+ def __repr__(self):
|
||||
+ return self.value
|
||||
|
||||
|
||||
def blocking_hash(hash):
|
||||
# See the collision-handling in lookdict (in Objects/dictobject.c).
|
||||
@@ -560,9 +624,10 @@ class OrderedDictTests:
|
||||
def test_issue24347(self):
|
||||
OrderedDict = self.OrderedDict
|
||||
|
||||
|
||||
- class Key:
|
||||
- def __hash__(self):
|
||||
- return randrange(100000)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key:
|
||||
+ def __hash__(self):
|
||||
+ return randrange(100000)
|
||||
|
||||
|
||||
od = OrderedDict()
|
||||
for i in range(100):
|
||||
@@ -582,9 +647,10 @@ class OrderedDictTests:
|
||||
def test_issue24348(self):
|
||||
OrderedDict = self.OrderedDict
|
||||
|
||||
|
||||
- class Key:
|
||||
- def __hash__(self):
|
||||
- return 1
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key:
|
||||
+ def __hash__(self):
|
||||
+ return 1
|
||||
|
||||
|
||||
od = OrderedDict()
|
||||
od[Key()] = 0
|
||||
@@ -760,15 +826,16 @@ class _TriggerSideEffectOnEqual:
|
||||
def side_effect(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
-class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
|
||||
+class PurePythonOrderedDictTests(OrderedDictTests, __TestCase):
|
||||
|
||||
|
||||
module = py_coll
|
||||
OrderedDict = py_coll.OrderedDict
|
||||
|
||||
|
||||
def test_issue119004_attribute_error(self):
|
||||
- class Key(_TriggerSideEffectOnEqual):
|
||||
- def side_effect(self):
|
||||
- del dict1[TODEL]
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key(_TriggerSideEffectOnEqual):
|
||||
+ def side_effect(self):
|
||||
+ del dict1[TODEL]
|
||||
|
||||
|
||||
TODEL = Key()
|
||||
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
|
||||
@@ -781,7 +848,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
|
||||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
|
||||
|
||||
|
||||
-class CPythonBuiltinDictTests(unittest.TestCase):
|
||||
+class CPythonBuiltinDictTests(__TestCase):
|
||||
"""Builtin dict preserves insertion order.
|
||||
|
||||
|
||||
Reuse some of tests in OrderedDict selectively.
|
||||
@@ -800,6 +867,7 @@ for method in (
|
||||
del method
|
||||
|
||||
|
||||
|
||||
|
||||
+
|
||||
class CPythonOrderedDictSideEffects:
|
||||
|
||||
|
||||
def check_runtime_error_issue119004(self, dict1, dict2):
|
||||
@@ -807,9 +875,10 @@ class CPythonOrderedDictSideEffects:
|
||||
self.assertRaisesRegex(RuntimeError, msg, operator.eq, dict1, dict2)
|
||||
|
||||
|
||||
def test_issue119004_change_size_by_clear(self):
|
||||
- class Key(_TriggerSideEffectOnEqual):
|
||||
- def side_effect(self):
|
||||
- dict1.clear()
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key(_TriggerSideEffectOnEqual):
|
||||
+ def side_effect(self):
|
||||
+ dict1.clear()
|
||||
|
||||
|
||||
dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
|
||||
dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
|
||||
@@ -819,9 +888,10 @@ class CPythonOrderedDictSideEffects:
|
||||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
|
||||
def test_issue119004_change_size_by_delete_key(self):
|
||||
- class Key(_TriggerSideEffectOnEqual):
|
||||
- def side_effect(self):
|
||||
- del dict1[TODEL]
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key(_TriggerSideEffectOnEqual):
|
||||
+ def side_effect(self):
|
||||
+ del dict1[TODEL]
|
||||
|
||||
|
||||
TODEL = Key()
|
||||
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
|
||||
@@ -832,10 +902,11 @@ class CPythonOrderedDictSideEffects:
|
||||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
|
||||
def test_issue119004_change_linked_list_by_clear(self):
|
||||
- class Key(_TriggerSideEffectOnEqual):
|
||||
- def side_effect(self):
|
||||
- dict1.clear()
|
||||
- dict1['a'] = dict1['b'] = 'c'
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key(_TriggerSideEffectOnEqual):
|
||||
+ def side_effect(self):
|
||||
+ dict1.clear()
|
||||
+ dict1['a'] = dict1['b'] = 'c'
|
||||
|
||||
|
||||
dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
|
||||
dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
|
||||
@@ -845,10 +916,11 @@ class CPythonOrderedDictSideEffects:
|
||||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
|
||||
def test_issue119004_change_linked_list_by_delete_key(self):
|
||||
- class Key(_TriggerSideEffectOnEqual):
|
||||
- def side_effect(self):
|
||||
- del dict1[TODEL]
|
||||
- dict1['a'] = 'c'
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key(_TriggerSideEffectOnEqual):
|
||||
+ def side_effect(self):
|
||||
+ del dict1[TODEL]
|
||||
+ dict1['a'] = 'c'
|
||||
|
||||
|
||||
TODEL = Key()
|
||||
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
|
||||
@@ -859,10 +931,11 @@ class CPythonOrderedDictSideEffects:
|
||||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
|
||||
def test_issue119004_change_size_by_delete_key_in_dict_eq(self):
|
||||
- class Key(_TriggerSideEffectOnEqual):
|
||||
- trigger = 0
|
||||
- def side_effect(self):
|
||||
- del dict1[TODEL]
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Key(_TriggerSideEffectOnEqual):
|
||||
+ trigger = 0
|
||||
+ def side_effect(self):
|
||||
+ del dict1[TODEL]
|
||||
|
||||
|
||||
TODEL = Key()
|
||||
dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2)))
|
||||
@@ -878,7 +951,7 @@ class CPythonOrderedDictSideEffects:
|
||||
|
|
@ -337,25 +337,25 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
CPythonOrderedDictSideEffects,
|
||||
- unittest.TestCase):
|
||||
+ __TestCase):
|
||||
|
||||
|
||||
module = c_coll
|
||||
OrderedDict = c_coll.OrderedDict
|
||||
@@ -986,7 +1059,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
-class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
|
||||
+class PurePythonOrderedDictWithSlotsCopyingTests(__TestCase):
|
||||
|
||||
|
||||
module = py_coll
|
||||
class OrderedDict(py_coll.OrderedDict):
|
||||
@@ -995,7 +1068,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
|
||||
-class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
|
||||
+class CPythonOrderedDictWithSlotsCopyingTests(__TestCase):
|
||||
|
||||
|
||||
module = c_coll
|
||||
class OrderedDict(c_coll.OrderedDict):
|
||||
@@ -1008,6 +1081,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
|
||||
|
|
@ -363,7 +363,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
def setUpClass(cls):
|
||||
cls.type2test = py_coll.OrderedDict
|
||||
+ super().setUpClass()
|
||||
|
||||
|
||||
def test_popitem(self):
|
||||
d = self._empty_mapping()
|
||||
@@ -1020,6 +1094,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
|
||||
|
|
@ -371,7 +371,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
def setUpClass(cls):
|
||||
cls.type2test = c_coll.OrderedDict
|
||||
+ super().setUpClass()
|
||||
|
||||
|
||||
def test_popitem(self):
|
||||
d = self._empty_mapping()
|
||||
@@ -1033,6 +1108,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
|
||||
|
|
@ -379,7 +379,7 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
pass
|
||||
cls.type2test = MyOrderedDict
|
||||
+ super().setUpClass()
|
||||
|
||||
|
||||
def test_popitem(self):
|
||||
d = self._empty_mapping()
|
||||
@@ -1047,6 +1123,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
|
||||
|
|
@ -387,32 +387,32 @@ index a9b6a84996e..efc4288d1a4 100644
|
|||
pass
|
||||
cls.type2test = MyOrderedDict
|
||||
+ super().setUpClass()
|
||||
|
||||
|
||||
def test_popitem(self):
|
||||
d = self._empty_mapping()
|
||||
@@ -1120,21 +1197,22 @@ class SimpleLRUCacheTests:
|
||||
self.assertEqual(list(c), [1, 3, 2])
|
||||
|
||||
|
||||
|
||||
|
||||
-class PySimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase):
|
||||
+class PySimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase):
|
||||
|
||||
|
||||
class type2test(SimpleLRUCache, py_coll.OrderedDict):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
|
||||
-class CSimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase):
|
||||
+class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase):
|
||||
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
class type2test(SimpleLRUCache, c_coll.OrderedDict):
|
||||
pass
|
||||
cls.type2test = type2test
|
||||
+ super().setUpClass()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class OrderedDictTests:
|
|||
|
||||
def test_init_calls(self):
|
||||
calls = []
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Spam:
|
||||
def keys(self):
|
||||
calls.append('keys')
|
||||
|
|
@ -187,7 +187,7 @@ class OrderedDictTests:
|
|||
# a consistent internal state is created in __new__
|
||||
# rather than __init__.
|
||||
OrderedDict = self.OrderedDict
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ODNI(OrderedDict):
|
||||
def __init__(*args, **kwargs):
|
||||
pass
|
||||
|
|
@ -326,7 +326,7 @@ class OrderedDictTests:
|
|||
self.assertEqual(od.pop(k, 12345), 12345)
|
||||
|
||||
# make sure pop still works when __missing__ is defined
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Missing(OrderedDict):
|
||||
def __missing__(self, key):
|
||||
return 0
|
||||
|
|
@ -476,7 +476,7 @@ class OrderedDictTests:
|
|||
self.assertEqual(od.setdefault('g', default=9), 9)
|
||||
|
||||
# make sure setdefault still works when __missing__ is defined
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Missing(OrderedDict):
|
||||
def __missing__(self, key):
|
||||
return 0
|
||||
|
|
@ -545,7 +545,7 @@ class OrderedDictTests:
|
|||
def test_override_update(self):
|
||||
OrderedDict = self.OrderedDict
|
||||
# Verify that subclasses can override update() without breaking __init__()
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyOD(OrderedDict):
|
||||
def update(self, *args, **kwds):
|
||||
raise Exception()
|
||||
|
|
@ -569,7 +569,7 @@ class OrderedDictTests:
|
|||
# should not crash Python.
|
||||
OrderedDict = self.OrderedDict
|
||||
deleted = []
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyOD(OrderedDict):
|
||||
def __del__(self):
|
||||
deleted.append(self.i)
|
||||
|
|
@ -584,7 +584,7 @@ class OrderedDictTests:
|
|||
def test_delitem_hash_collision(self):
|
||||
OrderedDict = self.OrderedDict
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key:
|
||||
def __init__(self, hash):
|
||||
self._hash = hash
|
||||
|
|
@ -624,7 +624,7 @@ class OrderedDictTests:
|
|||
def test_issue24347(self):
|
||||
OrderedDict = self.OrderedDict
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key:
|
||||
def __hash__(self):
|
||||
return randrange(100000)
|
||||
|
|
@ -647,7 +647,7 @@ class OrderedDictTests:
|
|||
def test_issue24348(self):
|
||||
OrderedDict = self.OrderedDict
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key:
|
||||
def __hash__(self):
|
||||
return 1
|
||||
|
|
@ -832,7 +832,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, __TestCase):
|
|||
OrderedDict = py_coll.OrderedDict
|
||||
|
||||
def test_issue119004_attribute_error(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key(_TriggerSideEffectOnEqual):
|
||||
def side_effect(self):
|
||||
del dict1[TODEL]
|
||||
|
|
@ -875,7 +875,7 @@ class CPythonOrderedDictSideEffects:
|
|||
self.assertRaisesRegex(RuntimeError, msg, operator.eq, dict1, dict2)
|
||||
|
||||
def test_issue119004_change_size_by_clear(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key(_TriggerSideEffectOnEqual):
|
||||
def side_effect(self):
|
||||
dict1.clear()
|
||||
|
|
@ -888,7 +888,7 @@ class CPythonOrderedDictSideEffects:
|
|||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
def test_issue119004_change_size_by_delete_key(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key(_TriggerSideEffectOnEqual):
|
||||
def side_effect(self):
|
||||
del dict1[TODEL]
|
||||
|
|
@ -902,7 +902,7 @@ class CPythonOrderedDictSideEffects:
|
|||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
def test_issue119004_change_linked_list_by_clear(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key(_TriggerSideEffectOnEqual):
|
||||
def side_effect(self):
|
||||
dict1.clear()
|
||||
|
|
@ -916,7 +916,7 @@ class CPythonOrderedDictSideEffects:
|
|||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
def test_issue119004_change_linked_list_by_delete_key(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key(_TriggerSideEffectOnEqual):
|
||||
def side_effect(self):
|
||||
del dict1[TODEL]
|
||||
|
|
@ -931,7 +931,7 @@ class CPythonOrderedDictSideEffects:
|
|||
self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))
|
||||
|
||||
def test_issue119004_change_size_by_delete_key_in_dict_eq(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Key(_TriggerSideEffectOnEqual):
|
||||
trigger = 0
|
||||
def side_effect(self):
|
||||
|
|
|
|||
|
|
@ -62,23 +62,23 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
@@ -38,7 +91,7 @@ class HashCountingInt(int):
|
||||
self.hash_count += 1
|
||||
return int.__hash__(self)
|
||||
|
||||
|
||||
-class TestJointOps:
|
||||
+class _TestJointOps:
|
||||
# Tests common to both set and frozenset
|
||||
|
||||
|
||||
def setUp(self):
|
||||
@@ -47,6 +100,7 @@ class TestJointOps:
|
||||
self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
self.s = self.thetype(word)
|
||||
self.d = dict.fromkeys(word)
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_new_or_init(self):
|
||||
self.assertRaises(TypeError, self.thetype, [], 2)
|
||||
@@ -261,13 +315,14 @@ class TestJointOps:
|
||||
self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
|
||||
|
||||
|
||||
def test_deepcopy(self):
|
||||
- class Tracer:
|
||||
- def __init__(self, value):
|
||||
|
|
@ -87,7 +87,7 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
- return self.value
|
||||
- def __deepcopy__(self, memo=None):
|
||||
- return Tracer(self.value + 1)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Tracer:
|
||||
+ def __init__(self, value):
|
||||
+ self.value = value
|
||||
|
|
@ -99,25 +99,25 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
s = self.thetype([t])
|
||||
dup = copy.deepcopy(s)
|
||||
@@ -279,8 +334,9 @@ class TestJointOps:
|
||||
|
||||
|
||||
def test_gc(self):
|
||||
# Create a nest of cycles to exercise overall ref count check
|
||||
- class A:
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class A:
|
||||
+ pass
|
||||
s = set(A() for i in range(1000))
|
||||
for elem in s:
|
||||
elem.cycle = s
|
||||
@@ -289,9 +345,10 @@ class TestJointOps:
|
||||
|
||||
|
||||
def test_subclass_with_custom_hash(self):
|
||||
# Bug #1257731
|
||||
- class H(self.thetype):
|
||||
- def __hash__(self):
|
||||
- return int(id(self) & 0x7fffffff)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class H(self.thetype):
|
||||
+ def __hash__(self):
|
||||
+ return int(id(self) & 0x7fffffff)
|
||||
|
|
@ -125,12 +125,12 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
f=set()
|
||||
f.add(s)
|
||||
@@ -342,8 +399,9 @@ class TestJointOps:
|
||||
|
||||
|
||||
def test_container_iterator(self):
|
||||
# Bug #3680: tp_traverse was not implemented for set iterator object
|
||||
- class C(object):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C(object):
|
||||
+ pass
|
||||
obj = C()
|
||||
|
|
@ -139,15 +139,15 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
@@ -355,7 +413,7 @@ class TestJointOps:
|
||||
def test_free_after_iterating(self):
|
||||
support.check_free_after_iterating(self, iter, self.thetype)
|
||||
|
||||
|
||||
-class TestSet(TestJointOps, unittest.TestCase):
|
||||
+class TestSet(_TestJointOps, __TestCase):
|
||||
thetype = set
|
||||
basetype = set
|
||||
|
||||
|
||||
@@ -600,19 +658,20 @@ class TestSet(TestJointOps, unittest.TestCase):
|
||||
self.assertRaises(ReferenceError, str, p)
|
||||
|
||||
|
||||
def test_rich_compare(self):
|
||||
- class TestRichSetCompare:
|
||||
- def __gt__(self, some_set):
|
||||
|
|
@ -162,7 +162,7 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
- def __le__(self, some_set):
|
||||
- self.le_called = True
|
||||
- return False
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class TestRichSetCompare:
|
||||
+ def __gt__(self, some_set):
|
||||
+ self.gt_called = True
|
||||
|
|
@ -176,16 +176,16 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
+ def __le__(self, some_set):
|
||||
+ self.le_called = True
|
||||
+ return False
|
||||
|
||||
|
||||
# This first tries the builtin rich set comparison, which doesn't know
|
||||
# how to handle the custom object. Upon returning NotImplemented, the
|
||||
@@ -644,28 +703,31 @@ class TestSetSubclass(TestSet):
|
||||
basetype = set
|
||||
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
- class subclass(set):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass(set):
|
||||
+ pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -193,12 +193,12 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.assertEqual(set(u), {1, 2})
|
||||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
|
||||
- class subclass_with_init(set):
|
||||
- def __init__(self, arg, newarg=None):
|
||||
- super().__init__(arg)
|
||||
- self.newarg = newarg
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_init(set):
|
||||
+ def __init__(self, arg, newarg=None):
|
||||
+ super().__init__(arg)
|
||||
|
|
@ -207,13 +207,13 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.assertIs(type(u), subclass_with_init)
|
||||
self.assertEqual(set(u), {1, 2})
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
|
||||
- class subclass_with_new(set):
|
||||
- def __new__(cls, arg, newarg=None):
|
||||
- self = super().__new__(cls, arg)
|
||||
- self.newarg = newarg
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_new(set):
|
||||
+ def __new__(cls, arg, newarg=None):
|
||||
+ self = super().__new__(cls, arg)
|
||||
|
|
@ -224,20 +224,20 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.assertEqual(set(u), {1, 2})
|
||||
@@ -675,7 +737,7 @@ class TestSetSubclass(TestSet):
|
||||
subclass_with_new([1, 2], newarg=3)
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestFrozenSet(TestJointOps, unittest.TestCase):
|
||||
+class TestFrozenSet(_TestJointOps, __TestCase):
|
||||
thetype = frozenset
|
||||
basetype = frozenset
|
||||
|
||||
|
||||
@@ -756,27 +818,30 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
||||
basetype = frozenset
|
||||
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
- class subclass(frozenset):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass(frozenset):
|
||||
+ pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -245,11 +245,11 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.assertEqual(set(u), {1, 2})
|
||||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
|
||||
- class subclass_with_init(frozenset):
|
||||
- def __init__(self, arg, newarg=None):
|
||||
- self.newarg = newarg
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_init(frozenset):
|
||||
+ def __init__(self, arg, newarg=None):
|
||||
+ self.newarg = newarg
|
||||
|
|
@ -257,13 +257,13 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.assertIs(type(u), subclass_with_init)
|
||||
self.assertEqual(set(u), {1, 2})
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
|
||||
- class subclass_with_new(frozenset):
|
||||
- def __new__(cls, arg, newarg=None):
|
||||
- self = super().__new__(cls, arg)
|
||||
- self.newarg = newarg
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_new(frozenset):
|
||||
+ def __new__(cls, arg, newarg=None):
|
||||
+ self = super().__new__(cls, arg)
|
||||
|
|
@ -275,7 +275,7 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
@@ -811,10 +876,17 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
||||
class SetSubclassWithSlots(set):
|
||||
__slots__ = ('x', 'y', '__dict__')
|
||||
|
||||
|
||||
-class TestSetSubclassWithSlots(unittest.TestCase):
|
||||
+class TestSetSubclassWithSlots(__TestCase):
|
||||
thetype = SetSubclassWithSlots
|
||||
|
|
@ -290,22 +290,22 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
+ self.s = self.thetype(word)
|
||||
+ self.d = dict.fromkeys(word)
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
class FrozenSetSubclassWithSlots(frozenset):
|
||||
__slots__ = ('x', 'y', '__dict__')
|
||||
@@ -828,7 +900,7 @@ empty_set = set()
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestBasicOps:
|
||||
+class _TestBasicOps:
|
||||
|
||||
|
||||
def test_repr(self):
|
||||
if self.repr is not None:
|
||||
@@ -934,7 +1006,7 @@ class TestBasicOps:
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
|
||||
+class TestBasicOpsEmpty(_TestBasicOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -316,9 +316,9 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.length = 0
|
||||
self.repr = "set()"
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
|
||||
+class TestBasicOpsSingleton(_TestBasicOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -329,13 +329,13 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.length = 1
|
||||
self.repr = "{3}"
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_in(self):
|
||||
self.assertIn(3, self.set)
|
||||
@@ -962,7 +1036,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
|
||||
+class TestBasicOpsTuple(_TestBasicOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -346,13 +346,13 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.length = 1
|
||||
self.repr = "{(0, 'zero')}"
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_in(self):
|
||||
self.assertIn((0, "zero"), self.set)
|
||||
@@ -979,7 +1054,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
|
||||
+class TestBasicOpsTriple(_TestBasicOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -363,9 +363,9 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.length = 3
|
||||
self.repr = None
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestBasicOpsString(TestBasicOps, unittest.TestCase):
|
||||
+class TestBasicOpsString(_TestBasicOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -375,12 +375,12 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.dup = set(self.values)
|
||||
self.length = 3
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_repr(self):
|
||||
self.check_repr_against_values()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
|
||||
+class TestBasicOpsBytes(_TestBasicOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -390,12 +390,12 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.dup = set(self.values)
|
||||
self.length = 3
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_repr(self):
|
||||
self.check_repr_against_values()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
|
||||
+class TestBasicOpsMixedStringBytes(_TestBasicOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -406,71 +406,71 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.dup = set(self.values)
|
||||
self.length = 4
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_repr(self):
|
||||
self.check_repr_against_values()
|
||||
@@ -1038,7 +1117,7 @@ def baditer():
|
||||
def gooditer():
|
||||
yield True
|
||||
|
||||
|
||||
-class TestExceptionPropagation(unittest.TestCase):
|
||||
+class TestExceptionPropagation(__TestCase):
|
||||
"""SF 628246: Set constructor should not trap iterator TypeErrors"""
|
||||
|
||||
|
||||
def test_instanceWithException(self):
|
||||
@@ -1065,7 +1144,7 @@ class TestExceptionPropagation(unittest.TestCase):
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestSetOfSets(unittest.TestCase):
|
||||
+class TestSetOfSets(__TestCase):
|
||||
def test_constructor(self):
|
||||
inner = frozenset([1])
|
||||
outer = set([inner])
|
||||
@@ -1078,9 +1157,10 @@ class TestSetOfSets(unittest.TestCase):
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestBinaryOps(unittest.TestCase):
|
||||
+class TestBinaryOps(__TestCase):
|
||||
def setUp(self):
|
||||
self.set = set((2, 4, 6))
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_eq(self): # SF bug 643115
|
||||
self.assertEqual(self.set, set({2:1,4:3,6:5}))
|
||||
@@ -1151,9 +1231,10 @@ class TestBinaryOps(unittest.TestCase):
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestUpdateOps(unittest.TestCase):
|
||||
+class TestUpdateOps(__TestCase):
|
||||
def setUp(self):
|
||||
self.set = set((2, 4, 6))
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_union_subset(self):
|
||||
self.set |= set([2])
|
||||
@@ -1237,10 +1318,11 @@ class TestUpdateOps(unittest.TestCase):
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestMutate(unittest.TestCase):
|
||||
+class TestMutate(__TestCase):
|
||||
def setUp(self):
|
||||
self.values = ["a", "b", "c"]
|
||||
self.set = set(self.values)
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_add_present(self):
|
||||
self.set.add("c")
|
||||
@@ -1311,7 +1393,7 @@ class TestMutate(unittest.TestCase):
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestSubsets:
|
||||
+class _TestSubsets:
|
||||
|
||||
|
||||
case2method = {"<=": "issubset",
|
||||
">=": "issuperset",
|
||||
@@ -1334,22 +1416,22 @@ class TestSubsets:
|
||||
|
|
@ -483,7 +483,7 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
+ method = getattr(x, _TestSubsets.case2method[case])
|
||||
result = method(y)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
# Now do the same for the operands reversed.
|
||||
- rcase = TestSubsets.reverse[case]
|
||||
+ rcase = _TestSubsets.reverse[case]
|
||||
|
|
@ -496,61 +496,61 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
result = method(x)
|
||||
self.assertEqual(result, expected)
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
|
||||
+class TestSubsetEqualEmpty(_TestSubsets, __TestCase):
|
||||
left = set()
|
||||
right = set()
|
||||
name = "both empty"
|
||||
@@ -1357,7 +1439,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
|
||||
+class TestSubsetEqualNonEmpty(_TestSubsets, __TestCase):
|
||||
left = set([1, 2])
|
||||
right = set([1, 2])
|
||||
name = "equal pair"
|
||||
@@ -1365,7 +1447,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
|
||||
+class TestSubsetEmptyNonEmpty(_TestSubsets, __TestCase):
|
||||
left = set()
|
||||
right = set([1, 2])
|
||||
name = "one empty, one non-empty"
|
||||
@@ -1373,7 +1455,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestSubsetPartial(TestSubsets, unittest.TestCase):
|
||||
+class TestSubsetPartial(_TestSubsets, __TestCase):
|
||||
left = set([1])
|
||||
right = set([1, 2])
|
||||
name = "one a non-empty proper subset of other"
|
||||
@@ -1381,7 +1463,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase):
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
|
||||
+class TestSubsetNonOverlap(_TestSubsets, __TestCase):
|
||||
left = set([1])
|
||||
right = set([2])
|
||||
name = "neither empty, neither contains"
|
||||
@@ -1389,7 +1471,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestOnlySetsInBinaryOps:
|
||||
+class _TestOnlySetsInBinaryOps:
|
||||
|
||||
|
||||
def test_eq_ne(self):
|
||||
# Unlike the others, this is testing that == and != *are* allowed.
|
||||
@@ -1505,47 +1587,52 @@ class TestOnlySetsInBinaryOps:
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
|
||||
+class TestOnlySetsNumeric(_TestOnlySetsInBinaryOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -558,9 +558,9 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.other = 19
|
||||
self.otherIsIterable = False
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
|
||||
+class TestOnlySetsDict(_TestOnlySetsInBinaryOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -568,9 +568,9 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.other = {1:2, 3:4}
|
||||
self.otherIsIterable = True
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
|
||||
+class TestOnlySetsOperator(_TestOnlySetsInBinaryOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -578,9 +578,9 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.other = operator.add
|
||||
self.otherIsIterable = False
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
|
||||
+class TestOnlySetsTuple(_TestOnlySetsInBinaryOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -588,9 +588,9 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.other = (2, 4, 6)
|
||||
self.otherIsIterable = True
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
|
||||
+class TestOnlySetsString(_TestOnlySetsInBinaryOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -598,9 +598,9 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.other = 'abc'
|
||||
self.otherIsIterable = True
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
|
||||
+class TestOnlySetsGenerator(_TestOnlySetsInBinaryOps, __TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -611,80 +611,80 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
self.other = gen()
|
||||
self.otherIsIterable = True
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestCopying:
|
||||
+class _TestCopying:
|
||||
|
||||
|
||||
def test_copy(self):
|
||||
dup = self.set.copy()
|
||||
@@ -1577,40 +1665,46 @@ class TestCopying:
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestCopyingEmpty(TestCopying, unittest.TestCase):
|
||||
+class TestCopyingEmpty(_TestCopying, __TestCase):
|
||||
def setUp(self):
|
||||
self.set = set()
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestCopyingSingleton(TestCopying, unittest.TestCase):
|
||||
+class TestCopyingSingleton(_TestCopying, __TestCase):
|
||||
def setUp(self):
|
||||
self.set = set(["hello"])
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestCopyingTriple(TestCopying, unittest.TestCase):
|
||||
+class TestCopyingTriple(_TestCopying, __TestCase):
|
||||
def setUp(self):
|
||||
self.set = set(["zero", 0, None])
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestCopyingTuple(TestCopying, unittest.TestCase):
|
||||
+class TestCopyingTuple(_TestCopying, __TestCase):
|
||||
def setUp(self):
|
||||
self.set = set([(1, 2)])
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-class TestCopyingNested(TestCopying, unittest.TestCase):
|
||||
+class TestCopyingNested(_TestCopying, __TestCase):
|
||||
def setUp(self):
|
||||
self.set = set([((1, 2), (3, 4))])
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestIdentities(unittest.TestCase):
|
||||
+class TestIdentities(__TestCase):
|
||||
def setUp(self):
|
||||
self.a = set('abracadabra')
|
||||
self.b = set('alacazam')
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def test_binopsVsSubsets(self):
|
||||
a, b = self.a, self.b
|
||||
@@ -1727,7 +1821,7 @@ def L(seqn):
|
||||
'Test multiple tiers of iterators'
|
||||
return chain(map(lambda x:x, R(Ig(G(seqn)))))
|
||||
|
||||
|
||||
-class TestVariousIteratorArgs(unittest.TestCase):
|
||||
+class TestVariousIteratorArgs(__TestCase):
|
||||
|
||||
|
||||
def test_constructor(self):
|
||||
for cons in (set, frozenset):
|
||||
@@ -1785,7 +1879,7 @@ class bad_dict_clear:
|
||||
def __hash__(self):
|
||||
return 0
|
||||
|
||||
|
||||
-class TestWeirdBugs(unittest.TestCase):
|
||||
+class TestWeirdBugs(__TestCase):
|
||||
def test_8420_set_merge(self):
|
||||
|
|
@ -692,7 +692,7 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
global be_bad, set2, dict2
|
||||
@@ -1813,12 +1907,13 @@ class TestWeirdBugs(unittest.TestCase):
|
||||
list(si)
|
||||
|
||||
|
||||
def test_merge_and_mutate(self):
|
||||
- class X:
|
||||
- def __hash__(self):
|
||||
|
|
@ -700,27 +700,27 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
- def __eq__(self, o):
|
||||
- other.clear()
|
||||
- return False
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class X:
|
||||
+ def __hash__(self):
|
||||
+ return hash(0)
|
||||
+ def __eq__(self, o):
|
||||
+ other.clear()
|
||||
+ return False
|
||||
|
||||
|
||||
other = set()
|
||||
other = {X() for i in range(10)}
|
||||
@@ -1826,24 +1921,25 @@ class TestWeirdBugs(unittest.TestCase):
|
||||
s.update(other)
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestOperationsMutating:
|
||||
+class _TestOperationsMutating:
|
||||
"""Regression test for bpo-46615"""
|
||||
|
||||
|
||||
constructor1 = None
|
||||
constructor2 = None
|
||||
|
||||
|
||||
def make_sets_of_bad_objects(self):
|
||||
- class Bad:
|
||||
- def __eq__(self, other):
|
||||
|
|
@ -733,7 +733,7 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
- return bool(randrange(2))
|
||||
- def __hash__(self):
|
||||
- return randrange(2)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Bad:
|
||||
+ def __eq__(self, other):
|
||||
+ if not enabled:
|
||||
|
|
@ -750,89 +750,89 @@ index d9102eb98a5..c8ee5ca451f 100644
|
|||
set1 = self.constructor1(Bad() for _ in range(randrange(50)))
|
||||
@@ -1862,7 +1958,7 @@ class TestOperationsMutating:
|
||||
self.assertIn("changed size during iteration", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestBinaryOpsMutating(TestOperationsMutating):
|
||||
+class _TestBinaryOpsMutating(_TestOperationsMutating):
|
||||
|
||||
|
||||
def test_eq_with_mutation(self):
|
||||
self.check_set_op_does_not_crash(lambda a, b: a == b)
|
||||
@@ -1933,24 +2029,24 @@ class TestBinaryOpsMutating(TestOperationsMutating):
|
||||
self.check_set_op_does_not_crash(f3)
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
|
||||
+class TestBinaryOpsMutating_Set_Set(_TestBinaryOpsMutating, __TestCase):
|
||||
constructor1 = set
|
||||
constructor2 = set
|
||||
|
||||
|
||||
-class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
|
||||
+class TestBinaryOpsMutating_Subclass_Subclass(_TestBinaryOpsMutating, __TestCase):
|
||||
constructor1 = SetSubclass
|
||||
constructor2 = SetSubclass
|
||||
|
||||
|
||||
-class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
|
||||
+class TestBinaryOpsMutating_Set_Subclass(_TestBinaryOpsMutating, __TestCase):
|
||||
constructor1 = set
|
||||
constructor2 = SetSubclass
|
||||
|
||||
|
||||
-class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
|
||||
+class TestBinaryOpsMutating_Subclass_Set(_TestBinaryOpsMutating, __TestCase):
|
||||
constructor1 = SetSubclass
|
||||
constructor2 = set
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestMethodsMutating(TestOperationsMutating):
|
||||
+class _TestMethodsMutating(_TestOperationsMutating):
|
||||
|
||||
|
||||
def test_issubset_with_mutation(self):
|
||||
self.check_set_op_does_not_crash(set.issubset)
|
||||
@@ -1986,27 +2082,27 @@ class TestMethodsMutating(TestOperationsMutating):
|
||||
self.check_set_op_does_not_crash(set.update)
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
|
||||
+class TestMethodsMutating_Set_Set(_TestMethodsMutating, __TestCase):
|
||||
constructor1 = set
|
||||
constructor2 = set
|
||||
|
||||
|
||||
-class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
|
||||
+class TestMethodsMutating_Subclass_Subclass(_TestMethodsMutating, __TestCase):
|
||||
constructor1 = SetSubclass
|
||||
constructor2 = SetSubclass
|
||||
|
||||
|
||||
-class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
|
||||
+class TestMethodsMutating_Set_Subclass(_TestMethodsMutating, __TestCase):
|
||||
constructor1 = set
|
||||
constructor2 = SetSubclass
|
||||
|
||||
|
||||
-class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
|
||||
+class TestMethodsMutating_Subclass_Set(_TestMethodsMutating, __TestCase):
|
||||
constructor1 = SetSubclass
|
||||
constructor2 = set
|
||||
|
||||
|
||||
-class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
|
||||
+class TestMethodsMutating_Set_Dict(_TestMethodsMutating, __TestCase):
|
||||
constructor1 = set
|
||||
constructor2 = dict.fromkeys
|
||||
|
||||
|
||||
-class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
|
||||
+class TestMethodsMutating_Set_List(_TestMethodsMutating, __TestCase):
|
||||
constructor1 = set
|
||||
constructor2 = list
|
||||
|
||||
|
||||
@@ -2068,7 +2164,7 @@ def faces(G):
|
||||
return f
|
||||
|
||||
|
||||
|
||||
|
||||
-class TestGraphs(unittest.TestCase):
|
||||
+class TestGraphs(__TestCase):
|
||||
|
||||
|
||||
def test_cube(self):
|
||||
|
||||
|
||||
@@ -2118,4 +2214,4 @@ class TestGraphs(unittest.TestCase):
|
||||
#==============================================================================
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -315,7 +315,7 @@ class _TestJointOps:
|
|||
self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
|
||||
|
||||
def test_deepcopy(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Tracer:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
|
@ -334,7 +334,7 @@ class _TestJointOps:
|
|||
|
||||
def test_gc(self):
|
||||
# Create a nest of cycles to exercise overall ref count check
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class A:
|
||||
pass
|
||||
s = set(A() for i in range(1000))
|
||||
|
|
@ -345,7 +345,7 @@ class _TestJointOps:
|
|||
|
||||
def test_subclass_with_custom_hash(self):
|
||||
# Bug #1257731
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class H(self.thetype):
|
||||
def __hash__(self):
|
||||
return int(id(self) & 0x7fffffff)
|
||||
|
|
@ -399,7 +399,7 @@ class _TestJointOps:
|
|||
|
||||
def test_container_iterator(self):
|
||||
# Bug #3680: tp_traverse was not implemented for set iterator object
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C(object):
|
||||
pass
|
||||
obj = C()
|
||||
|
|
@ -658,7 +658,7 @@ class TestSet(_TestJointOps, __TestCase):
|
|||
self.assertRaises(ReferenceError, str, p)
|
||||
|
||||
def test_rich_compare(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class TestRichSetCompare:
|
||||
def __gt__(self, some_set):
|
||||
self.gt_called = True
|
||||
|
|
@ -703,7 +703,7 @@ class TestSetSubclass(TestSet):
|
|||
basetype = set
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass(set):
|
||||
pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -712,7 +712,7 @@ class TestSetSubclass(TestSet):
|
|||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_init(set):
|
||||
def __init__(self, arg, newarg=None):
|
||||
super().__init__(arg)
|
||||
|
|
@ -722,7 +722,7 @@ class TestSetSubclass(TestSet):
|
|||
self.assertEqual(set(u), {1, 2})
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_new(set):
|
||||
def __new__(cls, arg, newarg=None):
|
||||
self = super().__new__(cls, arg)
|
||||
|
|
@ -818,7 +818,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
|||
basetype = frozenset
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass(frozenset):
|
||||
pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -827,7 +827,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
|||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_init(frozenset):
|
||||
def __init__(self, arg, newarg=None):
|
||||
self.newarg = newarg
|
||||
|
|
@ -836,7 +836,7 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
|||
self.assertEqual(set(u), {1, 2})
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_new(frozenset):
|
||||
def __new__(cls, arg, newarg=None):
|
||||
self = super().__new__(cls, arg)
|
||||
|
|
@ -1907,7 +1907,7 @@ class TestWeirdBugs(__TestCase):
|
|||
list(si)
|
||||
|
||||
def test_merge_and_mutate(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class X:
|
||||
def __hash__(self):
|
||||
return hash(0)
|
||||
|
|
@ -1928,7 +1928,7 @@ class _TestOperationsMutating:
|
|||
constructor2 = None
|
||||
|
||||
def make_sets_of_bad_objects(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Bad:
|
||||
def __eq__(self, other):
|
||||
if not enabled:
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
@@ -39,7 +93,7 @@ def check(tag, expected, raw, compare=None):
|
||||
nerrors += 1
|
||||
return
|
||||
|
||||
|
||||
-class TestBase(unittest.TestCase):
|
||||
+class TestBase(__TestCase):
|
||||
def testStressfully(self):
|
||||
|
|
@ -72,18 +72,18 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
@@ -48,32 +102,33 @@ class TestBase(unittest.TestCase):
|
||||
sizes.extend(range(n-1, n+2))
|
||||
sizes.extend([10, 100, 1000])
|
||||
|
||||
|
||||
- class Complains(object):
|
||||
- maybe_complain = True
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Complains(object):
|
||||
+ maybe_complain = True
|
||||
|
||||
|
||||
- def __init__(self, i):
|
||||
- self.i = i
|
||||
+ def __init__(self, i):
|
||||
+ self.i = i
|
||||
|
||||
|
||||
- def __lt__(self, other):
|
||||
- if Complains.maybe_complain and random.random() < 0.001:
|
||||
- if verbose:
|
||||
|
|
@ -96,12 +96,12 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
+ print(" complaining at", self, other)
|
||||
+ raise RuntimeError
|
||||
+ return self.i < other.i
|
||||
|
||||
|
||||
- def __repr__(self):
|
||||
- return "Complains(%d)" % self.i
|
||||
+ def __repr__(self):
|
||||
+ return "Complains(%d)" % self.i
|
||||
|
||||
|
||||
- class Stable(object):
|
||||
- def __init__(self, key, i):
|
||||
- self.key = key
|
||||
|
|
@ -110,31 +110,31 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
+ def __init__(self, key, i):
|
||||
+ self.key = key
|
||||
+ self.index = i
|
||||
|
||||
|
||||
- def __lt__(self, other):
|
||||
- return self.key < other.key
|
||||
+ def __lt__(self, other):
|
||||
+ return self.key < other.key
|
||||
|
||||
|
||||
- def __repr__(self):
|
||||
- return "Stable(%d, %d)" % (self.key, self.index)
|
||||
+ def __repr__(self):
|
||||
+ return "Stable(%d, %d)" % (self.key, self.index)
|
||||
|
||||
|
||||
for n in sizes:
|
||||
x = list(range(n))
|
||||
@@ -151,20 +206,21 @@ class TestBase(unittest.TestCase):
|
||||
self.assertEqual(forced, native)
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestBugs(unittest.TestCase):
|
||||
+class TestBugs(__TestCase):
|
||||
|
||||
|
||||
def test_bug453523(self):
|
||||
# bug 453523 -- list.sort() crasher.
|
||||
# If this fails, the most likely outcome is a core dump.
|
||||
# Mutations during a list sort should raise a ValueError.
|
||||
|
||||
|
||||
- class C:
|
||||
- def __lt__(self, other):
|
||||
- if L and random.random() < 0.75:
|
||||
|
|
@ -142,7 +142,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
- else:
|
||||
- L.append(3)
|
||||
- return random.random() < 0.5
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ def __lt__(self, other):
|
||||
+ if L and random.random() < 0.75:
|
||||
|
|
@ -150,20 +150,20 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
+ else:
|
||||
+ L.append(3)
|
||||
+ return random.random() < 0.5
|
||||
|
||||
|
||||
L = [C() for i in range(50)]
|
||||
self.assertRaises(ValueError, L.sort)
|
||||
@@ -188,7 +244,7 @@ class TestBugs(unittest.TestCase):
|
||||
|
||||
|
||||
#==============================================================================
|
||||
|
||||
|
||||
-class TestDecorateSortUndecorate(unittest.TestCase):
|
||||
+class TestDecorateSortUndecorate(__TestCase):
|
||||
|
||||
|
||||
def test_decorated(self):
|
||||
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
|
||||
@@ -228,26 +284,28 @@ class TestDecorateSortUndecorate(unittest.TestCase):
|
||||
|
||||
|
||||
def test_key_with_mutating_del(self):
|
||||
data = list(range(10))
|
||||
- class SortKiller(object):
|
||||
|
|
@ -174,7 +174,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
- data[:] = range(20)
|
||||
- def __lt__(self, other):
|
||||
- return id(self) < id(other)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class SortKiller(object):
|
||||
+ def __init__(self, x):
|
||||
+ pass
|
||||
|
|
@ -184,7 +184,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
+ def __lt__(self, other):
|
||||
+ return id(self) < id(other)
|
||||
self.assertRaises(ValueError, data.sort, key=SortKiller)
|
||||
|
||||
|
||||
def test_key_with_mutating_del_and_exception(self):
|
||||
data = list(range(10))
|
||||
## dup = data[:]
|
||||
|
|
@ -195,7 +195,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
- def __del__(self):
|
||||
- del data[:]
|
||||
- data[:] = list(range(20))
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class SortKiller(object):
|
||||
+ def __init__(self, x):
|
||||
+ if x > 2:
|
||||
|
|
@ -209,7 +209,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
@@ -309,7 +367,7 @@ def check_against_PyObject_RichCompareBool(self, L):
|
||||
self.assertIs(opt, ref)
|
||||
#note: not assertEqual! We want to ensure *identical* behavior.
|
||||
|
||||
|
||||
-class TestOptimizedCompares(unittest.TestCase):
|
||||
+class TestOptimizedCompares(__TestCase):
|
||||
def test_safe_object_compare(self):
|
||||
|
|
@ -218,39 +218,39 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
@@ -331,17 +389,18 @@ class TestOptimizedCompares(unittest.TestCase):
|
||||
# This test is by ppperry. It ensures that unsafe_object_compare is
|
||||
# verifying ms->key_richcompare == tp->richcompare before comparing.
|
||||
|
||||
|
||||
- class WackyComparator(int):
|
||||
- def __lt__(self, other):
|
||||
- elem.__class__ = WackyList2
|
||||
- return int.__lt__(self, other)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class WackyComparator(int):
|
||||
+ def __lt__(self, other):
|
||||
+ elem.__class__ = WackyList2
|
||||
+ return int.__lt__(self, other)
|
||||
|
||||
|
||||
- class WackyList1(list):
|
||||
- pass
|
||||
+ class WackyList1(list):
|
||||
+ pass
|
||||
|
||||
|
||||
- class WackyList2(list):
|
||||
- def __lt__(self, other):
|
||||
- raise ValueError
|
||||
+ class WackyList2(list):
|
||||
+ def __lt__(self, other):
|
||||
+ raise ValueError
|
||||
|
||||
|
||||
L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
|
||||
elem = L[-1]
|
||||
@@ -355,9 +414,10 @@ class TestOptimizedCompares(unittest.TestCase):
|
||||
|
||||
|
||||
# The following test is also by ppperry. It ensures that
|
||||
# unsafe_object_compare handles Py_NotImplemented appropriately.
|
||||
- class PointlessComparator:
|
||||
- def __lt__(self, other):
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class PointlessComparator:
|
||||
+ def __lt__(self, other):
|
||||
+ return NotImplemented
|
||||
|
|
@ -259,7 +259,7 @@ index 2a7cfb7affa..4805f1fcceb 100644
|
|||
self.assertRaises(TypeError, [(x,) for x in L].sort)
|
||||
@@ -408,4 +468,4 @@ class TestOptimizedCompares(unittest.TestCase):
|
||||
#==============================================================================
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ class TestBase(__TestCase):
|
|||
sizes.extend(range(n-1, n+2))
|
||||
sizes.extend([10, 100, 1000])
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Complains(object):
|
||||
maybe_complain = True
|
||||
|
||||
|
|
@ -213,7 +213,7 @@ class TestBugs(__TestCase):
|
|||
# If this fails, the most likely outcome is a core dump.
|
||||
# Mutations during a list sort should raise a ValueError.
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def __lt__(self, other):
|
||||
if L and random.random() < 0.75:
|
||||
|
|
@ -284,7 +284,7 @@ class TestDecorateSortUndecorate(__TestCase):
|
|||
|
||||
def test_key_with_mutating_del(self):
|
||||
data = list(range(10))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SortKiller(object):
|
||||
def __init__(self, x):
|
||||
pass
|
||||
|
|
@ -298,7 +298,7 @@ class TestDecorateSortUndecorate(__TestCase):
|
|||
def test_key_with_mutating_del_and_exception(self):
|
||||
data = list(range(10))
|
||||
## dup = data[:]
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class SortKiller(object):
|
||||
def __init__(self, x):
|
||||
if x > 2:
|
||||
|
|
@ -389,7 +389,7 @@ class TestOptimizedCompares(__TestCase):
|
|||
# This test is by ppperry. It ensures that unsafe_object_compare is
|
||||
# verifying ms->key_richcompare == tp->richcompare before comparing.
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class WackyComparator(int):
|
||||
def __lt__(self, other):
|
||||
elem.__class__ = WackyList2
|
||||
|
|
@ -414,7 +414,7 @@ class TestOptimizedCompares(__TestCase):
|
|||
|
||||
# The following test is also by ppperry. It ensures that
|
||||
# unsafe_object_compare handles Py_NotImplemented appropriately.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class PointlessComparator:
|
||||
def __lt__(self, other):
|
||||
return NotImplemented
|
||||
|
|
|
|||
|
|
@ -60,15 +60,15 @@ index 9ce80c5e8ea..1080e85e31a 100644
|
|||
+from test import support
|
||||
+import seq_tests
|
||||
import unittest
|
||||
|
||||
|
||||
import gc
|
||||
@@ -43,27 +97,30 @@ class TupleTest(seq_tests.CommonTest):
|
||||
tuple(sequence=())
|
||||
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
- class subclass(tuple):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass(tuple):
|
||||
+ pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -76,11 +76,11 @@ index 9ce80c5e8ea..1080e85e31a 100644
|
|||
self.assertEqual(list(u), [1, 2])
|
||||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
|
||||
- class subclass_with_init(tuple):
|
||||
- def __init__(self, arg, newarg=None):
|
||||
- self.newarg = newarg
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_init(tuple):
|
||||
+ def __init__(self, arg, newarg=None):
|
||||
+ self.newarg = newarg
|
||||
|
|
@ -88,13 +88,13 @@ index 9ce80c5e8ea..1080e85e31a 100644
|
|||
self.assertIs(type(u), subclass_with_init)
|
||||
self.assertEqual(list(u), [1, 2])
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
|
||||
- class subclass_with_new(tuple):
|
||||
- def __new__(cls, arg, newarg=None):
|
||||
- self = super().__new__(cls, arg)
|
||||
- self.newarg = newarg
|
||||
- return self
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class subclass_with_new(tuple):
|
||||
+ def __new__(cls, arg, newarg=None):
|
||||
+ self = super().__new__(cls, arg)
|
||||
|
|
@ -109,25 +109,25 @@ index 9ce80c5e8ea..1080e85e31a 100644
|
|||
# Tuple subtypes must always be tracked
|
||||
- class MyTuple(tuple):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class MyTuple(tuple):
|
||||
+ pass
|
||||
self.check_track_dynamic(MyTuple, True)
|
||||
|
||||
|
||||
@support.cpython_only
|
||||
@@ -404,7 +462,8 @@ class TupleTest(seq_tests.CommonTest):
|
||||
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
|
||||
# optimization causes failures in code that relies on distinct
|
||||
# function addresses.
|
||||
- class T(tuple): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class T(tuple): pass
|
||||
with self.assertRaises(TypeError):
|
||||
[3,] + T((1,2))
|
||||
|
||||
|
||||
@@ -510,4 +569,4 @@ class TupleTest(seq_tests.CommonTest):
|
||||
# pileup 262,143 mean 8.0 coll 262,143 z +92683.6
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ class TupleTest(seq_tests.CommonTest):
|
|||
tuple(sequence=())
|
||||
|
||||
def test_keywords_in_subclass(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass(tuple):
|
||||
pass
|
||||
u = subclass([1, 2])
|
||||
|
|
@ -106,7 +106,7 @@ class TupleTest(seq_tests.CommonTest):
|
|||
with self.assertRaises(TypeError):
|
||||
subclass(sequence=())
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_init(tuple):
|
||||
def __init__(self, arg, newarg=None):
|
||||
self.newarg = newarg
|
||||
|
|
@ -115,7 +115,7 @@ class TupleTest(seq_tests.CommonTest):
|
|||
self.assertEqual(list(u), [1, 2])
|
||||
self.assertEqual(u.newarg, 3)
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class subclass_with_new(tuple):
|
||||
def __new__(cls, arg, newarg=None):
|
||||
self = super().__new__(cls, arg)
|
||||
|
|
@ -408,7 +408,7 @@ class TupleTest(seq_tests.CommonTest):
|
|||
@support.cpython_only
|
||||
def test_track_subtypes(self):
|
||||
# Tuple subtypes must always be tracked
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class MyTuple(tuple):
|
||||
pass
|
||||
self.check_track_dynamic(MyTuple, True)
|
||||
|
|
@ -462,7 +462,7 @@ class TupleTest(seq_tests.CommonTest):
|
|||
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
|
||||
# optimization causes failures in code that relies on distinct
|
||||
# function addresses.
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class T(tuple): pass
|
||||
with self.assertRaises(TypeError):
|
||||
[3,] + T((1,2))
|
||||
|
|
|
|||
|
|
@ -58,29 +58,29 @@ index 312702c8e39..d3d8dbf394a 100644
|
|||
+# ======= END DYNAMO PATCH =======
|
||||
+
|
||||
# Check every path through every method of UserList
|
||||
|
||||
|
||||
from collections import UserList
|
||||
-from test import list_tests
|
||||
+import list_tests
|
||||
import unittest
|
||||
from test import support
|
||||
|
||||
|
||||
@@ -56,9 +110,10 @@ class UserListTest(list_tests.CommonTest):
|
||||
|
||||
|
||||
def test_getitemoverwriteiter(self):
|
||||
# Verify that __getitem__ overrides *are* recognized by __iter__
|
||||
- class T(self.type2test):
|
||||
- def __getitem__(self, key):
|
||||
- return str(key) + '!!!'
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class T(self.type2test):
|
||||
+ def __getitem__(self, key):
|
||||
+ return str(key) + '!!!'
|
||||
self.assertEqual(next(iter(T((1,2)))), "0!!!")
|
||||
|
||||
|
||||
def test_userlist_copy(self):
|
||||
@@ -69,9 +124,9 @@ class UserListTest(list_tests.CommonTest):
|
||||
|
||||
|
||||
# Decorate existing test with recursion limit, because
|
||||
# the test is for C structure, but `UserList` is a Python structure.
|
||||
- test_repr_deep = support.infinite_recursion(25)(
|
||||
|
|
@ -89,7 +89,7 @@ index 312702c8e39..d3d8dbf394a 100644
|
|||
+ # test_repr_deep = support.infinite_recursion(25)(
|
||||
+ # list_tests.CommonTest.test_repr_deep,
|
||||
+ # )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class UserListTest(list_tests.CommonTest):
|
|||
|
||||
def test_getitemoverwriteiter(self):
|
||||
# Verify that __getitem__ overrides *are* recognized by __iter__
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class T(self.type2test):
|
||||
def __getitem__(self, key):
|
||||
return str(key) + '!!!'
|
||||
|
|
|
|||
|
|
@ -24,84 +24,84 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
+# ======= END DYNAMO PATCH =======
|
||||
+
|
||||
"""Unit tests for the with statement specified in PEP 343."""
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -104,16 +124,17 @@ class MockNested(Nested):
|
||||
return Nested.__exit__(self, *exc_info)
|
||||
|
||||
|
||||
|
||||
|
||||
-class FailureTestCase(unittest.TestCase):
|
||||
+class FailureTestCase(__TestCase):
|
||||
def testNameError(self):
|
||||
def fooNotDeclared():
|
||||
with foo: pass
|
||||
self.assertRaises(NameError, fooNotDeclared)
|
||||
|
||||
|
||||
def testEnterAttributeError1(self):
|
||||
- class LacksEnter(object):
|
||||
- def __exit__(self, type, value, traceback):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class LacksEnter(object):
|
||||
+ def __exit__(self, type, value, traceback):
|
||||
+ pass
|
||||
|
||||
|
||||
def fooLacksEnter():
|
||||
foo = LacksEnter()
|
||||
@@ -121,8 +142,9 @@ class FailureTestCase(unittest.TestCase):
|
||||
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)
|
||||
|
||||
|
||||
def testEnterAttributeError2(self):
|
||||
- class LacksEnterAndExit(object):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class LacksEnterAndExit(object):
|
||||
+ pass
|
||||
|
||||
|
||||
def fooLacksEnterAndExit():
|
||||
foo = LacksEnterAndExit()
|
||||
@@ -130,9 +152,10 @@ class FailureTestCase(unittest.TestCase):
|
||||
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)
|
||||
|
||||
|
||||
def testExitAttributeError(self):
|
||||
- class LacksExit(object):
|
||||
- def __enter__(self):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class LacksExit(object):
|
||||
+ def __enter__(self):
|
||||
+ pass
|
||||
|
||||
|
||||
def fooLacksExit():
|
||||
foo = LacksExit()
|
||||
@@ -162,11 +185,12 @@ class FailureTestCase(unittest.TestCase):
|
||||
' pass')
|
||||
|
||||
|
||||
def testEnterThrows(self):
|
||||
- class EnterThrows(object):
|
||||
- def __enter__(self):
|
||||
- raise RuntimeError("Enter threw")
|
||||
- def __exit__(self, *args):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class EnterThrows(object):
|
||||
+ def __enter__(self):
|
||||
+ raise RuntimeError("Enter threw")
|
||||
+ def __exit__(self, *args):
|
||||
+ pass
|
||||
|
||||
|
||||
def shouldThrow():
|
||||
ct = EnterThrows()
|
||||
@@ -180,11 +204,12 @@ class FailureTestCase(unittest.TestCase):
|
||||
self.assertEqual(self.foo, None)
|
||||
|
||||
|
||||
def testExitThrows(self):
|
||||
- class ExitThrows(object):
|
||||
- def __enter__(self):
|
||||
- return
|
||||
- def __exit__(self, *args):
|
||||
- raise RuntimeError(42)
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class ExitThrows(object):
|
||||
+ def __enter__(self):
|
||||
+ return
|
||||
|
|
@ -111,17 +111,17 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
with ExitThrows():
|
||||
pass
|
||||
@@ -194,6 +219,7 @@ class ContextmanagerAssertionMixin(object):
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.TEST_EXCEPTION = RuntimeError("test exception")
|
||||
+ super().setUp()
|
||||
|
||||
|
||||
def assertInWithManagerInvariants(self, mock_manager):
|
||||
self.assertTrue(mock_manager.enter_called)
|
||||
@@ -237,7 +263,7 @@ class ContextmanagerAssertionMixin(object):
|
||||
self.assertTrue(mock_generator.stopped)
|
||||
|
||||
|
||||
|
||||
|
||||
-class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
|
||||
+class NonexceptionalTestCase(__TestCase, ContextmanagerAssertionMixin):
|
||||
def testInlineGeneratorSyntax(self):
|
||||
|
|
@ -129,8 +129,8 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
pass
|
||||
@@ -289,7 +315,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
|
||||
self.assertAfterWithGeneratorInvariantsNoError(foo)
|
||||
|
||||
|
||||
|
||||
|
||||
-class NestedNonexceptionalTestCase(unittest.TestCase,
|
||||
+class NestedNonexceptionalTestCase(__TestCase,
|
||||
ContextmanagerAssertionMixin):
|
||||
|
|
@ -138,15 +138,15 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
with Nested(mock_contextmanager_generator()):
|
||||
@@ -355,7 +381,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
|
||||
self.assertAfterWithManagerInvariantsNoError(mock_nested)
|
||||
|
||||
|
||||
|
||||
|
||||
-class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
||||
+class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
|
||||
def testSingleResource(self):
|
||||
cm = mock_contextmanager_generator()
|
||||
def shouldThrow():
|
||||
@@ -466,11 +492,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
||||
|
||||
|
||||
def testRaisedStopIteration2(self):
|
||||
# From bug 1462485
|
||||
- class cm(object):
|
||||
|
|
@ -154,17 +154,17 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
- pass
|
||||
- def __exit__(self, type, value, traceback):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class cm(object):
|
||||
+ def __enter__(self):
|
||||
+ pass
|
||||
+ def __exit__(self, type, value, traceback):
|
||||
+ pass
|
||||
|
||||
|
||||
def shouldThrow():
|
||||
with cm():
|
||||
@@ -507,11 +534,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
||||
|
||||
|
||||
def testRaisedGeneratorExit2(self):
|
||||
# From bug 1462485
|
||||
- class cm (object):
|
||||
|
|
@ -172,19 +172,19 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
- pass
|
||||
- def __exit__(self, type, value, traceback):
|
||||
- pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class cm (object):
|
||||
+ def __enter__(self):
|
||||
+ pass
|
||||
+ def __exit__(self, type, value, traceback):
|
||||
+ pass
|
||||
|
||||
|
||||
def shouldThrow():
|
||||
with cm():
|
||||
@@ -523,16 +551,17 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
||||
# issue4589: __exit__ return code may raise an exception
|
||||
# when looking at its truth value.
|
||||
|
||||
|
||||
- class cm(object):
|
||||
- def __init__(self, bool_conversion):
|
||||
- class Bool:
|
||||
|
|
@ -195,7 +195,7 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
- return 3
|
||||
- def __exit__(self, a, b, c):
|
||||
- return self.exit_result
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class cm(object):
|
||||
+ def __init__(self, bool_conversion):
|
||||
+ class Bool:
|
||||
|
|
@ -206,25 +206,25 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
+ return 3
|
||||
+ def __exit__(self, a, b, c):
|
||||
+ return self.exit_result
|
||||
|
||||
|
||||
def trueAsBool():
|
||||
with cm(lambda: True):
|
||||
@@ -550,7 +579,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
|
||||
self.assertRaises(ZeroDivisionError, failAsBool)
|
||||
|
||||
|
||||
|
||||
|
||||
-class NonLocalFlowControlTestCase(unittest.TestCase):
|
||||
+class NonLocalFlowControlTestCase(__TestCase):
|
||||
|
||||
|
||||
def testWithBreak(self):
|
||||
counter = 0
|
||||
@@ -607,7 +636,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase):
|
||||
self.fail("Didn't raise RuntimeError")
|
||||
|
||||
|
||||
|
||||
|
||||
-class AssignmentTargetTestCase(unittest.TestCase):
|
||||
+class AssignmentTargetTestCase(__TestCase):
|
||||
|
||||
|
||||
def testSingleComplexTarget(self):
|
||||
targets = {1: [0, 1, 2]}
|
||||
@@ -621,15 +650,17 @@ class AssignmentTargetTestCase(unittest.TestCase):
|
||||
|
|
@ -232,17 +232,17 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
keys.sort()
|
||||
self.assertEqual(keys, [1, 2])
|
||||
- class C: pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C: pass
|
||||
blah = C()
|
||||
with mock_contextmanager_generator() as blah.foo:
|
||||
self.assertEqual(hasattr(blah, "foo"), True)
|
||||
|
||||
|
||||
def testMultipleComplexTargets(self):
|
||||
- class C:
|
||||
- def __enter__(self): return 1, 2, 3
|
||||
- def __exit__(self, t, v, tb): pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class C:
|
||||
+ def __enter__(self): return 1, 2, 3
|
||||
+ def __exit__(self, t, v, tb): pass
|
||||
|
|
@ -254,23 +254,23 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
with C() as (targets[1], targets[2], targets[3]):
|
||||
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
|
||||
- class B: pass
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class B: pass
|
||||
blah = B()
|
||||
with C() as (blah.one, blah.two, blah.three):
|
||||
self.assertEqual(blah.one, 1)
|
||||
@@ -651,12 +683,13 @@ class AssignmentTargetTestCase(unittest.TestCase):
|
||||
self.assertEqual(c, 4)
|
||||
|
||||
|
||||
|
||||
|
||||
-class ExitSwallowsExceptionTestCase(unittest.TestCase):
|
||||
+class ExitSwallowsExceptionTestCase(__TestCase):
|
||||
|
||||
|
||||
def testExitTrueSwallowsException(self):
|
||||
- class AfricanSwallow:
|
||||
- def __enter__(self): pass
|
||||
- def __exit__(self, t, v, tb): return True
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class AfricanSwallow:
|
||||
+ def __enter__(self): pass
|
||||
+ def __exit__(self, t, v, tb): return True
|
||||
|
|
@ -279,12 +279,12 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
1/0
|
||||
@@ -664,9 +697,10 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
|
||||
self.fail("ZeroDivisionError should have been swallowed")
|
||||
|
||||
|
||||
def testExitFalseDoesntSwallowException(self):
|
||||
- class EuropeanSwallow:
|
||||
- def __enter__(self): pass
|
||||
- def __exit__(self, t, v, tb): return False
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class EuropeanSwallow:
|
||||
+ def __enter__(self): pass
|
||||
+ def __exit__(self, t, v, tb): return False
|
||||
|
|
@ -293,16 +293,16 @@ index 8e9ed8500c7..66c18ad886a 100644
|
|||
1/0
|
||||
@@ -676,7 +710,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
|
||||
self.fail("ZeroDivisionError should have been raised")
|
||||
|
||||
|
||||
|
||||
|
||||
-class NestedWith(unittest.TestCase):
|
||||
+class NestedWith(__TestCase):
|
||||
|
||||
|
||||
class Dummy(object):
|
||||
def __init__(self, value=None, gobble=False):
|
||||
@@ -796,4 +830,4 @@ class NestedWith(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
- unittest.main()
|
||||
+ run_tests()
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class FailureTestCase(__TestCase):
|
|||
self.assertRaises(NameError, fooNotDeclared)
|
||||
|
||||
def testEnterAttributeError1(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class LacksEnter(object):
|
||||
def __exit__(self, type, value, traceback):
|
||||
pass
|
||||
|
|
@ -142,7 +142,7 @@ class FailureTestCase(__TestCase):
|
|||
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)
|
||||
|
||||
def testEnterAttributeError2(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class LacksEnterAndExit(object):
|
||||
pass
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ class FailureTestCase(__TestCase):
|
|||
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)
|
||||
|
||||
def testExitAttributeError(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class LacksExit(object):
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
|
@ -185,7 +185,7 @@ class FailureTestCase(__TestCase):
|
|||
' pass')
|
||||
|
||||
def testEnterThrows(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class EnterThrows(object):
|
||||
def __enter__(self):
|
||||
raise RuntimeError("Enter threw")
|
||||
|
|
@ -204,7 +204,7 @@ class FailureTestCase(__TestCase):
|
|||
self.assertEqual(self.foo, None)
|
||||
|
||||
def testExitThrows(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class ExitThrows(object):
|
||||
def __enter__(self):
|
||||
return
|
||||
|
|
@ -492,7 +492,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
|
|||
|
||||
def testRaisedStopIteration2(self):
|
||||
# From bug 1462485
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class cm(object):
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
|
@ -534,7 +534,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
|
|||
|
||||
def testRaisedGeneratorExit2(self):
|
||||
# From bug 1462485
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class cm (object):
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
|
@ -551,7 +551,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, __TestCase):
|
|||
# issue4589: __exit__ return code may raise an exception
|
||||
# when looking at its truth value.
|
||||
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class cm(object):
|
||||
def __init__(self, bool_conversion):
|
||||
class Bool:
|
||||
|
|
@ -650,14 +650,14 @@ class AssignmentTargetTestCase(__TestCase):
|
|||
keys = list(targets.keys())
|
||||
keys.sort()
|
||||
self.assertEqual(keys, [1, 2])
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C: pass
|
||||
blah = C()
|
||||
with mock_contextmanager_generator() as blah.foo:
|
||||
self.assertEqual(hasattr(blah, "foo"), True)
|
||||
|
||||
def testMultipleComplexTargets(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class C:
|
||||
def __enter__(self): return 1, 2, 3
|
||||
def __exit__(self, t, v, tb): pass
|
||||
|
|
@ -668,7 +668,7 @@ class AssignmentTargetTestCase(__TestCase):
|
|||
self.assertEqual(targets, {1: [3, 2, 1]})
|
||||
with C() as (targets[1], targets[2], targets[3]):
|
||||
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class B: pass
|
||||
blah = B()
|
||||
with C() as (blah.one, blah.two, blah.three):
|
||||
|
|
@ -686,7 +686,7 @@ class AssignmentTargetTestCase(__TestCase):
|
|||
class ExitSwallowsExceptionTestCase(__TestCase):
|
||||
|
||||
def testExitTrueSwallowsException(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class AfricanSwallow:
|
||||
def __enter__(self): pass
|
||||
def __exit__(self, t, v, tb): return True
|
||||
|
|
@ -697,7 +697,7 @@ class ExitSwallowsExceptionTestCase(__TestCase):
|
|||
self.fail("ZeroDivisionError should have been swallowed")
|
||||
|
||||
def testExitFalseDoesntSwallowException(self):
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class EuropeanSwallow:
|
||||
def __enter__(self): pass
|
||||
def __exit__(self, t, v, tb): return False
|
||||
|
|
|
|||
|
|
@ -1721,13 +1721,13 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
):
|
||||
f4(torch.randn(3))
|
||||
|
||||
def test_set_fullgraph(self):
|
||||
def test_error_on_graph_break(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f1(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
|
|
@ -1738,7 +1738,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
@torch.compile(backend=cnts)
|
||||
def f2(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
with torch._dynamo.error_on_graph_break(True):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
|
|
@ -1748,7 +1748,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f3(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
torch._dynamo.graph_break()
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
|
|
@ -1766,7 +1766,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f4(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
torch._dynamo.skip_frame()
|
||||
return inner_f4(x)
|
||||
|
||||
|
|
@ -1774,11 +1774,11 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(f4(inp), inp + 7)
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
def test_set_fullgraph_nested(self):
|
||||
# set_fullgraph in a nested frame
|
||||
def test_error_on_graph_break_nested(self):
|
||||
# error_on_graph_break in a nested frame
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.set_fullgraph(False)
|
||||
@torch._dynamo.error_on_graph_break(False)
|
||||
def inner_f5(x):
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
|
|
@ -1795,7 +1795,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
def inner_f6(x):
|
||||
x = x + 2
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
|
|
@ -1810,7 +1810,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
def inner_f7(x):
|
||||
x = x + 2
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
with torch._dynamo.error_on_graph_break(True):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
|
|
@ -1822,18 +1822,18 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaises(Unsupported):
|
||||
f7(inp)
|
||||
|
||||
def test_set_fullgraph_nested_with_skip(self):
|
||||
# set_fullgraph in a nested frame with a skipped frame in between
|
||||
def test_error_on_graph_break_nested_with_skip(self):
|
||||
# error_on_graph_break in a nested frame with a skipped frame in between
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.set_fullgraph(False)
|
||||
@torch._dynamo.error_on_graph_break(False)
|
||||
def inner2_f8(x):
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
def inner1_f8(x):
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
torch._dynamo.skip_frame()
|
||||
return inner2_f8(x)
|
||||
|
||||
|
|
@ -1848,7 +1848,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
def inner2_f9(x):
|
||||
x = x + 2
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
with torch._dynamo.error_on_graph_break(True):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
|
|
@ -1864,10 +1864,10 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaises(Unsupported):
|
||||
f9(inp)
|
||||
|
||||
# test export with set_fullgraph(False) still errors
|
||||
# test export with error_on_graph_break(False) still errors
|
||||
|
||||
def test_set_fullgraph_export(self):
|
||||
@torch._dynamo.set_fullgraph(False)
|
||||
def test_error_on_graph_break_export(self):
|
||||
@torch._dynamo.error_on_graph_break(False)
|
||||
def inner(x):
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
|
|
@ -1880,7 +1880,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaises(Unsupported):
|
||||
torch._dynamo.export(f)(torch.ones(3))
|
||||
|
||||
def test_set_fullgraph_nested_deep(self):
|
||||
def test_error_on_graph_break_nested_deep(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
def inner1_f1(x):
|
||||
|
|
@ -1892,7 +1892,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
return inner1_f1(x)
|
||||
|
||||
def inner3_f1(x):
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
return inner2_f1(x)
|
||||
|
||||
def inner4_f1(x):
|
||||
|
|
@ -1916,7 +1916,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
return inner1_f2(x)
|
||||
|
||||
def inner3_f2(x):
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
with torch._dynamo.error_on_graph_break(True):
|
||||
return inner2_f2(x)
|
||||
|
||||
def inner4_f2(x):
|
||||
|
|
@ -1930,20 +1930,20 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaises(Unsupported):
|
||||
f2(inp)
|
||||
|
||||
def test_set_fullgraph_error(self):
|
||||
def test_error_on_graph_break_error(self):
|
||||
@torch.compile(backend="eager")
|
||||
def f1():
|
||||
with torch._dynamo.set_fullgraph(foo="bar"):
|
||||
with torch._dynamo.error_on_graph_break(foo="bar"):
|
||||
pass
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def f2():
|
||||
with torch._dynamo.set_fullgraph():
|
||||
with torch._dynamo.error_on_graph_break():
|
||||
pass
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def f3():
|
||||
with torch._dynamo.set_fullgraph("foo"):
|
||||
with torch._dynamo.error_on_graph_break("foo"):
|
||||
pass
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from .decorators import (
|
|||
disable,
|
||||
disallow_in_graph,
|
||||
dont_skip_tracing,
|
||||
error_on_graph_break,
|
||||
forbid_in_graph,
|
||||
graph_break,
|
||||
mark_dynamic,
|
||||
|
|
@ -30,7 +31,6 @@ from .decorators import (
|
|||
nonstrict_trace,
|
||||
patch_dynamo_config,
|
||||
run,
|
||||
set_fullgraph,
|
||||
set_stance,
|
||||
skip_frame,
|
||||
substitute_in_graph,
|
||||
|
|
@ -90,7 +90,7 @@ __all__ = [
|
|||
"replay",
|
||||
"reset",
|
||||
"run",
|
||||
"set_fullgraph",
|
||||
"error_on_graph_break",
|
||||
"set_stance",
|
||||
"skip_frame",
|
||||
"substitute_in_graph",
|
||||
|
|
|
|||
|
|
@ -1747,7 +1747,7 @@ def replay(filename: str) -> None:
|
|||
record = ExecutionRecord.load(in_file)
|
||||
record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
|
||||
|
||||
with decorators.set_fullgraph(fullgraph=False):
|
||||
with decorators.error_on_graph_break(False):
|
||||
try:
|
||||
_compile(
|
||||
record.code,
|
||||
|
|
|
|||
|
|
@ -918,15 +918,15 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any:
|
|||
return ctx
|
||||
|
||||
|
||||
class SetFullgraphDecoratorContextManager:
|
||||
def __init__(self, fullgraph: bool) -> None:
|
||||
self.fullgraph = fullgraph
|
||||
class ErrorOnGraphBreakDecoratorContextManager:
|
||||
def __init__(self, error_on_graph_break: bool) -> None:
|
||||
self.error_on_graph_break = error_on_graph_break
|
||||
|
||||
__call__ = wrap_dunder_call_ctx_manager
|
||||
|
||||
def __enter__(self) -> None:
|
||||
self.prev_fullgraph = _get_error_on_graph_break()
|
||||
_set_error_on_graph_break(self.fullgraph)
|
||||
self.prev_error_on_graph_break = _get_error_on_graph_break()
|
||||
_set_error_on_graph_break(self.error_on_graph_break)
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
|
|
@ -934,14 +934,16 @@ class SetFullgraphDecoratorContextManager:
|
|||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
_set_error_on_graph_break(self.prev_fullgraph)
|
||||
_set_error_on_graph_break(self.prev_error_on_graph_break)
|
||||
|
||||
|
||||
def set_fullgraph(fullgraph: bool) -> SetFullgraphDecoratorContextManager:
|
||||
def error_on_graph_break(
|
||||
error_on_graph_break: bool,
|
||||
) -> ErrorOnGraphBreakDecoratorContextManager:
|
||||
"""
|
||||
Context manager/decorator to toggle fullgraph setting.
|
||||
Context manager/decorator to toggle error_on_graph_break (i.e. torch.compile's fullgraph) setting.
|
||||
|
||||
More precisely, when encountering a graph break, we will decide to resume (fullgraph=False)
|
||||
or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break.
|
||||
"""
|
||||
return SetFullgraphDecoratorContextManager(fullgraph)
|
||||
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)
|
||||
|
|
|
|||
|
|
@ -858,7 +858,9 @@ class _TorchDynamoContext:
|
|||
|
||||
# hooks to properly handle inlining
|
||||
compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined]
|
||||
external_utils.wrap_inline_with_set_fullgraph(fn, self.error_on_graph_break)
|
||||
external_utils.wrap_inline_with_error_on_graph_break(
|
||||
fn, self.error_on_graph_break
|
||||
)
|
||||
)
|
||||
|
||||
# Save the function pointer to find the original callable while nesting
|
||||
|
|
|
|||
|
|
@ -229,23 +229,23 @@ def call_accumulate_grad(
|
|||
variable.grad = updated_grad[0]
|
||||
|
||||
|
||||
def wrap_inline_with_set_fullgraph(
|
||||
fn: Callable[_P, _R], fullgraph: bool
|
||||
def wrap_inline_with_error_on_graph_break(
|
||||
fn: Callable[_P, _R], error_on_graph_break: bool
|
||||
) -> Callable[_P, _R]:
|
||||
# NB: need multiple definitions in order to prevent `fullgraph` from
|
||||
# being a freevar of wrapper
|
||||
if fullgraph:
|
||||
if error_on_graph_break:
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
with torch._dynamo.error_on_graph_break(True):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
else:
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
|
|
@ -348,7 +348,7 @@ manual_torch_name_rule_map: dict[
|
|||
"torch._dynamo.mark_static": UserFunctionVariable,
|
||||
"torch._dynamo.nonstrict_trace": UserFunctionVariable,
|
||||
"torch._dynamo.patch_dynamo_config": UserFunctionVariable,
|
||||
"torch._dynamo.set_fullgraph": UserFunctionVariable,
|
||||
"torch._dynamo.error_on_graph_break": UserFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from .ctx_manager import (
|
|||
DisabledSavedTensorsHooksVariable,
|
||||
DualLevelContextManager,
|
||||
DynamoConfigPatchVariable,
|
||||
ErrorOnGraphBreakVariable,
|
||||
FSDPParamGroupUseTrainingStateVariable,
|
||||
GradIncrementNestingCtxManagerVariable,
|
||||
GradInplaceRequiresGradCtxManagerVariable,
|
||||
|
|
@ -34,7 +35,6 @@ from .ctx_manager import (
|
|||
InferenceModeVariable,
|
||||
JvpIncrementNestingCtxManagerVariable,
|
||||
SDPAKernelVariable,
|
||||
SetFullgraphVariable,
|
||||
SetFwdGradEnabledContextManager,
|
||||
StreamContextVariable,
|
||||
StreamVariable,
|
||||
|
|
@ -200,7 +200,7 @@ __all__ = [
|
|||
"RemovableHandleVariable",
|
||||
"RepeatIteratorVariable",
|
||||
"SDPAParamsVariable",
|
||||
"SetFullgraphVariable",
|
||||
"ErrorOnGraphBreakVariable",
|
||||
"SkipFunctionVariable",
|
||||
"SliceVariable",
|
||||
"StringFormatVariable",
|
||||
|
|
|
|||
|
|
@ -168,10 +168,10 @@ from .constant import ConstantVariable, EnumVariable
|
|||
from .ctx_manager import (
|
||||
AutocastModeVariable,
|
||||
DynamoConfigPatchVariable,
|
||||
ErrorOnGraphBreakVariable,
|
||||
EventVariable,
|
||||
NullContextVariable,
|
||||
PreserveVersionContextVariable,
|
||||
SetFullgraphVariable,
|
||||
StreamContextVariable,
|
||||
StreamVariable,
|
||||
)
|
||||
|
|
@ -630,7 +630,7 @@ class VariableBuilder:
|
|||
|
||||
from ..decorators import (
|
||||
DynamoConfigPatchProxy,
|
||||
SetFullgraphDecoratorContextManager,
|
||||
ErrorOnGraphBreakDecoratorContextManager,
|
||||
)
|
||||
|
||||
if has_triton():
|
||||
|
|
@ -988,8 +988,8 @@ class VariableBuilder:
|
|||
)
|
||||
elif isinstance(value, DynamoConfigPatchProxy):
|
||||
return DynamoConfigPatchVariable(value.changes)
|
||||
elif isinstance(value, SetFullgraphDecoratorContextManager):
|
||||
return SetFullgraphVariable(value.fullgraph)
|
||||
elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager):
|
||||
return ErrorOnGraphBreakVariable(value.error_on_graph_break)
|
||||
elif callable(value) and trace_rules.lookup_callable(value) is not None:
|
||||
if trace_rules.is_callable_allowed(value):
|
||||
self.tx.output.has_user_defined_allowed_in_graph = True
|
||||
|
|
|
|||
|
|
@ -1429,12 +1429,12 @@ class DynamoConfigPatchVariable(ContextWrappingVariable):
|
|||
return "patch_dynamo_config"
|
||||
|
||||
|
||||
class SetFullgraphVariable(ContextWrappingVariable):
|
||||
"""represents torch._dynamo.set_fullgraph"""
|
||||
class ErrorOnGraphBreakVariable(ContextWrappingVariable):
|
||||
"""represents torch._dynamo.error_on_graph_break"""
|
||||
|
||||
def __init__(self, fullgraph, **kwargs) -> None:
|
||||
def __init__(self, error_on_graph_break, **kwargs) -> None:
|
||||
super().__init__(
|
||||
target_values=(fullgraph,),
|
||||
target_values=(error_on_graph_break,),
|
||||
initial_values=(_get_error_on_graph_break(),),
|
||||
**kwargs,
|
||||
)
|
||||
|
|
@ -1447,7 +1447,7 @@ class SetFullgraphVariable(ContextWrappingVariable):
|
|||
return "torch._dynamo"
|
||||
|
||||
def fn_name(self):
|
||||
return "set_fullgraph"
|
||||
return "error_on_graph_break"
|
||||
|
||||
|
||||
class WithExitFunctionVariable(VariableTracker):
|
||||
|
|
|
|||
|
|
@ -521,15 +521,17 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
|||
"Please fix your call to patch_dynamo_config by using simpler inputs. "
|
||||
f"args: {args}, kwargs: {kwargs}"
|
||||
) from e
|
||||
elif self.fn is torch._dynamo.set_fullgraph:
|
||||
elif self.fn is torch._dynamo.error_on_graph_break:
|
||||
try:
|
||||
bound = inspect.signature(self.fn).bind(*args, **kwargs)
|
||||
fullgraph = bound.arguments["fullgraph"].as_python_constant()
|
||||
assert isinstance(fullgraph, bool)
|
||||
return variables.SetFullgraphVariable(fullgraph)
|
||||
error_on_graph_break = bound.arguments[
|
||||
"error_on_graph_break"
|
||||
].as_python_constant()
|
||||
assert isinstance(error_on_graph_break, bool)
|
||||
return variables.ErrorOnGraphBreakVariable(error_on_graph_break)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Improper set_fullgraph() call. Please fix your call to set_fullgraph(). "
|
||||
"Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). "
|
||||
f"args: {args}, kwargs: {kwargs}"
|
||||
) from e
|
||||
# Handle a `nonstrict_trace(fn)` call
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user