mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph`, so that `fullgraph` really returns one graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still a private API (there are no internal call sites yet, and there are no significant OSS call sites yet).

cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
434 lines
18 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
|
|
index 7d5ba727389..ff514815da2 100644
|
|
--- a/test/dynamo/cpython/3_13/test_itertools.py
|
|
+++ b/test/dynamo/cpython/3_13/test_itertools.py
|
|
@@ -1,3 +1,25 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_itertools.py
|
|
+
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import (
|
|
+ run_tests,
|
|
+ skipIfTorchDynamo,
|
|
+ slowTest,
|
|
+)
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
import doctest
|
|
import unittest
|
|
import itertools
|
|
@@ -40,6 +62,14 @@ def pickle_deprecated(testfunc):
|
|
maxsize = support.MAX_Py_ssize_t
|
|
minsize = -maxsize-1
|
|
|
|
+@torch._dynamo.disable
|
|
+def choice(*args):
|
|
+ return random.choice(*args)
|
|
+
|
|
+@torch._dynamo.disable
|
|
+def randrange(*args):
|
|
+ return random.randrange(*args)
|
|
+
|
|
def lzip(*args):
|
|
return list(zip(*args))
|
|
|
|
@@ -90,10 +120,10 @@ def fact(n):
|
|
return prod(range(1, n+1))
|
|
|
|
# root level methods for pickling ability
|
|
-def testR(r):
|
|
+def _testR(r):
|
|
return r[0]
|
|
|
|
-def testR2(r):
|
|
+def _testR2(r):
|
|
return r[2]
|
|
|
|
def underten(x):
|
|
@@ -102,7 +132,7 @@ def underten(x):
|
|
picklecopiers = [lambda s, proto=proto: pickle.loads(pickle.dumps(s, proto))
|
|
for proto in range(pickle.HIGHEST_PROTOCOL + 1)]
|
|
|
|
-class TestBasicOps(unittest.TestCase):
|
|
+class TestBasicOps(__TestCase):
|
|
|
|
def pickletest(self, protocol, it, stop=4, take=1, compare=None):
|
|
"""Test that an iterator is the same after pickling, also when part-consumed"""
|
|
@@ -454,14 +484,8 @@ class TestBasicOps(unittest.TestCase):
|
|
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
|
|
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
|
|
|
|
- @pickle_deprecated
|
|
def test_permutations(self):
|
|
- self.assertRaises(TypeError, permutations) # too few arguments
|
|
- self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
|
|
- self.assertRaises(TypeError, permutations, None) # pool is not iterable
|
|
- self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
|
|
self.assertEqual(list(permutations('abc', 32)), []) # r > n
|
|
- self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
|
|
self.assertEqual(list(permutations(range(3), 2)),
|
|
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
|
|
|
|
@@ -498,7 +522,7 @@ class TestBasicOps(unittest.TestCase):
|
|
if len(set(indices)) == r:
|
|
yield tuple(pool[i] for i in indices)
|
|
|
|
- for n in range(7):
|
|
+ for n in range(5):
|
|
values = [5*x-12 for x in range(n)]
|
|
for r in range(n+2):
|
|
result = list(permutations(values, r))
|
|
@@ -515,9 +539,6 @@ class TestBasicOps(unittest.TestCase):
|
|
self.assertEqual(result, list(permutations(values, None))) # test r as None
|
|
self.assertEqual(result, list(permutations(values))) # test default r
|
|
|
|
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
- self.pickletest(proto, permutations(values, r)) # test pickling
|
|
-
|
|
@support.bigaddrspacetest
|
|
def test_permutations_overflow(self):
|
|
with self.assertRaises((OverflowError, MemoryError)):
|
|
@@ -756,7 +777,7 @@ class TestBasicOps(unittest.TestCase):
|
|
def test_cycle(self):
|
|
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
|
|
self.assertEqual(list(cycle('')), [])
|
|
- self.assertRaises(TypeError, cycle)
|
|
+ # self.assertRaises(TypeError, cycle)
|
|
self.assertRaises(TypeError, cycle, 5)
|
|
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
|
|
|
|
@@ -888,7 +909,7 @@ class TestBasicOps(unittest.TestCase):
|
|
# Check normal pickled
|
|
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
dup = []
|
|
- for k, g in pickle.loads(pickle.dumps(groupby(s, testR), proto)):
|
|
+ for k, g in pickle.loads(pickle.dumps(groupby(s, _testR), proto)):
|
|
for elem in g:
|
|
self.assertEqual(k, elem[0])
|
|
dup.append(elem)
|
|
@@ -896,8 +917,8 @@ class TestBasicOps(unittest.TestCase):
|
|
|
|
# Check nested case
|
|
dup = []
|
|
- for k, g in groupby(s, testR):
|
|
- for ik, ig in groupby(g, testR2):
|
|
+ for k, g in groupby(s, _testR):
|
|
+ for ik, ig in groupby(g, _testR2):
|
|
for elem in ig:
|
|
self.assertEqual(k, elem[0])
|
|
self.assertEqual(ik, elem[2])
|
|
@@ -907,8 +928,8 @@ class TestBasicOps(unittest.TestCase):
|
|
# Check nested and pickled
|
|
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
dup = []
|
|
- for k, g in pickle.loads(pickle.dumps(groupby(s, testR), proto)):
|
|
- for ik, ig in pickle.loads(pickle.dumps(groupby(g, testR2), proto)):
|
|
+ for k, g in pickle.loads(pickle.dumps(groupby(s, _testR), proto)):
|
|
+ for ik, ig in pickle.loads(pickle.dumps(groupby(g, _testR2), proto)):
|
|
for elem in ig:
|
|
self.assertEqual(k, elem[0])
|
|
self.assertEqual(ik, elem[2])
|
|
@@ -917,7 +938,7 @@ class TestBasicOps(unittest.TestCase):
|
|
|
|
|
|
# Check case where inner iterator is not used
|
|
- keys = [k for k, g in groupby(s, testR)]
|
|
+ keys = [k for k, g in groupby(s, _testR)]
|
|
expectedkeys = set([r[0] for r in s])
|
|
self.assertEqual(set(keys), expectedkeys)
|
|
self.assertEqual(len(keys), len(expectedkeys))
|
|
@@ -925,7 +946,7 @@ class TestBasicOps(unittest.TestCase):
|
|
# Check case where inner iterator is used after advancing the groupby
|
|
# iterator
|
|
s = list(zip('AABBBAAAA', range(9)))
|
|
- it = groupby(s, testR)
|
|
+ it = groupby(s, _testR)
|
|
_, g1 = next(it)
|
|
_, g2 = next(it)
|
|
_, g3 = next(it)
|
|
@@ -936,7 +957,7 @@ class TestBasicOps(unittest.TestCase):
|
|
self.assertEqual(list(g3), [])
|
|
|
|
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
- it = groupby(s, testR)
|
|
+ it = groupby(s, _testR)
|
|
_, g = next(it)
|
|
next(it)
|
|
next(it)
|
|
@@ -1002,29 +1023,30 @@ class TestBasicOps(unittest.TestCase):
|
|
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
|
|
self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2])
|
|
self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6])
|
|
- self.assertRaises(TypeError, filter)
|
|
- self.assertRaises(TypeError, filter, lambda x:x)
|
|
- self.assertRaises(TypeError, filter, lambda x:x, range(6), 7)
|
|
- self.assertRaises(TypeError, filter, isEven, 3)
|
|
- self.assertRaises(TypeError, next, filter(range(6), range(6)))
|
|
+ # these tests raise dynamo exceptions, not TypeError
|
|
+ # self.assertRaises(TypeError, filter)
|
|
+ # self.assertRaises(TypeError, filter, lambda x:x)
|
|
+ # self.assertRaises(TypeError, filter, lambda x:x, range(6), 7)
|
|
+ # self.assertRaises(TypeError, filter, isEven, 3)
|
|
+ # dynamo raises Unsupported in this case
|
|
+ # self.assertRaises(TypeError, next, filter(range(6), range(6)))
|
|
|
|
# check copy, deepcopy, pickle
|
|
- ans = [0,2,4]
|
|
-
|
|
- c = filter(isEven, range(6))
|
|
- self.assertEqual(list(copy.copy(c)), ans)
|
|
- c = filter(isEven, range(6))
|
|
- self.assertEqual(list(copy.deepcopy(c)), ans)
|
|
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
- c = filter(isEven, range(6))
|
|
- self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans)
|
|
- next(c)
|
|
- self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:])
|
|
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
- c = filter(isEven, range(6))
|
|
- self.pickletest(proto, c)
|
|
+ # ans = [0,2,4]
|
|
+
|
|
+ # c = filter(isEven, range(6))
|
|
+ # self.assertEqual(list(copy.copy(c)), ans)
|
|
+ # c = filter(isEven, range(6))
|
|
+ # self.assertEqual(list(copy.deepcopy(c)), ans)
|
|
+ # for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
+ # c = filter(isEven, range(6))
|
|
+ # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans)
|
|
+ # next(c)
|
|
+ # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:])
|
|
+ # for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
+ # c = filter(isEven, range(6))
|
|
+ # self.pickletest(proto, c)
|
|
|
|
- @pickle_deprecated
|
|
def test_filterfalse(self):
|
|
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
|
|
self.assertEqual(list(filterfalse(None, [0,1,0,2,0])), [0,0,0])
|
|
@@ -1034,9 +1056,10 @@ class TestBasicOps(unittest.TestCase):
|
|
self.assertRaises(TypeError, filterfalse, lambda x:x)
|
|
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
|
|
self.assertRaises(TypeError, filterfalse, isEven, 3)
|
|
- self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
|
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
- self.pickletest(proto, filterfalse(isEven, range(6)))
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
|
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
|
+ self.pickletest(proto, filterfalse(isEven, range(6)))
|
|
|
|
def test_zip(self):
|
|
# XXX This is rather silly now that builtin zip() calls zip()...
|
|
@@ -1047,8 +1070,8 @@ class TestBasicOps(unittest.TestCase):
|
|
self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3)))
|
|
self.assertEqual(list(zip('abcdef')), lzip('abcdef'))
|
|
self.assertEqual(list(zip()), lzip())
|
|
- self.assertRaises(TypeError, zip, 3)
|
|
- self.assertRaises(TypeError, zip, range(3), 3)
|
|
+ # self.assertRaises(TypeError, zip, 3)
|
|
+ # self.assertRaises(TypeError, zip, range(3), 3)
|
|
self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')],
|
|
lzip('abc', 'def'))
|
|
self.assertEqual([pair for pair in zip('abc', 'def')],
|
|
@@ -1105,19 +1128,19 @@ class TestBasicOps(unittest.TestCase):
|
|
|
|
self.assertEqual(list(zip_longest('abc', 'defg', **{})),
|
|
list(zip(list('abc')+[None], 'defg'))) # empty keyword dict
|
|
- self.assertRaises(TypeError, zip_longest, 3)
|
|
- self.assertRaises(TypeError, zip_longest, range(3), 3)
|
|
-
|
|
- for stmt in [
|
|
- "zip_longest('abc', fv=1)",
|
|
- "zip_longest('abc', fillvalue=1, bogus_keyword=None)",
|
|
- ]:
|
|
- try:
|
|
- eval(stmt, globals(), locals())
|
|
- except TypeError:
|
|
- pass
|
|
- else:
|
|
- self.fail('Did not raise Type in: ' + stmt)
|
|
+ # self.assertRaises(TypeError, zip_longest, 3)
|
|
+ # self.assertRaises(TypeError, zip_longest, range(3), 3)
|
|
+
|
|
+ # for stmt in [
|
|
+ # "zip_longest('abc', fv=1)",
|
|
+ # "zip_longest('abc', fillvalue=1, bogus_keyword=None)",
|
|
+ # ]:
|
|
+ # try:
|
|
+ # eval(stmt, globals(), locals())
|
|
+ # except TypeError:
|
|
+ # pass
|
|
+ # else:
|
|
+ # self.fail('Did not raise Type in: ' + stmt)
|
|
|
|
self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')],
|
|
list(zip('abc', 'def')))
|
|
@@ -1296,7 +1319,6 @@ class TestBasicOps(unittest.TestCase):
|
|
self.assertEqual(list(product(*(args*r))),
|
|
list(product(*args, **dict(repeat=r))))
|
|
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
|
|
- self.assertRaises(TypeError, product, range(6), None)
|
|
|
|
def product1(*args, **kwds):
|
|
pools = list(map(tuple, args)) * kwds.get('repeat', 1)
|
|
@@ -1336,7 +1358,8 @@ class TestBasicOps(unittest.TestCase):
|
|
argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
|
|
set('abcdefg'), range(11), tuple(range(13))]
|
|
for i in range(100):
|
|
- args = [random.choice(argtypes) for j in range(random.randrange(5))]
|
|
+        with torch._dynamo.error_on_graph_break(False):
|
|
+ args = [choice(argtypes) for j in range(randrange(5))]
|
|
expected_len = prod(map(len, args))
|
|
self.assertEqual(len(list(product(*args))), expected_len)
|
|
self.assertEqual(list(product(*args)), list(product1(*args)))
|
|
@@ -1767,6 +1790,7 @@ class TestBasicOps(unittest.TestCase):
|
|
script_helper.assert_python_ok("-c", script)
|
|
|
|
# Issue 13454: Crash when deleting backward iterator from tee()
|
|
+ @skipIfTorchDynamo("infinite loop in torch dynamo")
|
|
def test_tee_del_backward(self):
|
|
forward, backward = tee(repeat(None, 20000000))
|
|
try:
|
|
@@ -1920,7 +1944,7 @@ class TestBasicOps(unittest.TestCase):
|
|
tp.foobar = 1
|
|
|
|
|
|
-class TestExamples(unittest.TestCase):
|
|
+class TestExamples(__TestCase):
|
|
|
|
def test_accumulate(self):
|
|
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
|
|
@@ -2032,7 +2056,7 @@ class TestExamples(unittest.TestCase):
|
|
self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
|
|
|
|
|
|
-class TestPurePythonRoughEquivalents(unittest.TestCase):
|
|
+class TestPurePythonRoughEquivalents(__TestCase):
|
|
|
|
def test_batched_recipe(self):
|
|
def batched_recipe(iterable, n):
|
|
@@ -2081,6 +2105,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
|
for i, element in zip(range(i + 1, stop), iterable):
|
|
pass
|
|
|
|
+ @skipIfTorchDynamo("infinite loop in torch dynamo")
|
|
def test_islice_recipe(self):
|
|
self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB'))
|
|
self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD'))
|
|
@@ -2265,7 +2290,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
|
|
raise
|
|
|
|
|
|
-class TestGC(unittest.TestCase):
|
|
+class TestGC(__TestCase):
|
|
|
|
def makecycle(self, iterator, container):
|
|
container.append(iterator)
|
|
@@ -2465,7 +2490,7 @@ def L(seqn):
|
|
return chain(map(lambda x:x, R(Ig(G(seqn)))))
|
|
|
|
|
|
-class TestVariousIteratorArgs(unittest.TestCase):
|
|
+class TestVariousIteratorArgs(__TestCase):
|
|
|
|
def test_accumulate(self):
|
|
s = [1,2,3,4,5]
|
|
@@ -2644,7 +2669,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
|
|
self.assertRaises(TypeError, tee, N(s))
|
|
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
|
|
|
|
-class LengthTransparency(unittest.TestCase):
|
|
+class LengthTransparency(__TestCase):
|
|
|
|
def test_repeat(self):
|
|
self.assertEqual(operator.length_hint(repeat(None, 50)), 50)
|
|
@@ -2657,7 +2682,7 @@ class LengthTransparency(unittest.TestCase):
|
|
self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0)
|
|
self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0)
|
|
|
|
-class RegressionTests(unittest.TestCase):
|
|
+class RegressionTests(__TestCase):
|
|
|
|
def test_sf_793826(self):
|
|
# Fix Armin Rigo's successful efforts to wreak havoc
|
|
@@ -2718,6 +2743,7 @@ class RegressionTests(unittest.TestCase):
|
|
|
|
@support.skip_if_pgo_task
|
|
@support.requires_resource('cpu')
|
|
+ @slowTest
|
|
def test_long_chain_of_empty_iterables(self):
|
|
# Make sure itertools.chain doesn't run into recursion limits when
|
|
# dealing with long chains of empty iterables. Even with a high
|
|
@@ -2750,7 +2776,7 @@ class RegressionTests(unittest.TestCase):
|
|
next(g, None) # shouldn't crash
|
|
|
|
|
|
-class SubclassWithKwargsTest(unittest.TestCase):
|
|
+class SubclassWithKwargsTest(__TestCase):
|
|
def test_keywords_in_subclass(self):
|
|
# count is not subclassable...
|
|
testcases = [
|
|
@@ -2805,49 +2831,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
|
|
self.assertEqual(u.newarg, 3)
|
|
|
|
|
|
-@support.cpython_only
|
|
-class SizeofTest(unittest.TestCase):
|
|
- def setUp(self):
|
|
- self.ssize_t = struct.calcsize('n')
|
|
-
|
|
- check_sizeof = support.check_sizeof
|
|
-
|
|
- def test_product_sizeof(self):
|
|
- basesize = support.calcobjsize('3Pi')
|
|
- check = self.check_sizeof
|
|
- check(product('ab', '12'), basesize + 2 * self.ssize_t)
|
|
- check(product(*(('abc',) * 10)), basesize + 10 * self.ssize_t)
|
|
-
|
|
- def test_combinations_sizeof(self):
|
|
- basesize = support.calcobjsize('3Pni')
|
|
- check = self.check_sizeof
|
|
- check(combinations('abcd', 3), basesize + 3 * self.ssize_t)
|
|
- check(combinations(range(10), 4), basesize + 4 * self.ssize_t)
|
|
-
|
|
- def test_combinations_with_replacement_sizeof(self):
|
|
- cwr = combinations_with_replacement
|
|
- basesize = support.calcobjsize('3Pni')
|
|
- check = self.check_sizeof
|
|
- check(cwr('abcd', 3), basesize + 3 * self.ssize_t)
|
|
- check(cwr(range(10), 4), basesize + 4 * self.ssize_t)
|
|
-
|
|
- def test_permutations_sizeof(self):
|
|
- basesize = support.calcobjsize('4Pni')
|
|
- check = self.check_sizeof
|
|
- check(permutations('abcd'),
|
|
- basesize + 4 * self.ssize_t + 4 * self.ssize_t)
|
|
- check(permutations('abcd', 3),
|
|
- basesize + 4 * self.ssize_t + 3 * self.ssize_t)
|
|
- check(permutations('abcde', 3),
|
|
- basesize + 5 * self.ssize_t + 3 * self.ssize_t)
|
|
- check(permutations(range(10), 4),
|
|
- basesize + 10 * self.ssize_t + 4 * self.ssize_t)
|
|
-
|
|
-
|
|
-def load_tests(loader, tests, pattern):
|
|
- tests.addTest(doctest.DocTestSuite(itertools))
|
|
- return tests
|
|
-
|
|
-
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|