pytorch/test/dynamo/cpython/3_13/test_iter.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Rename `set_fullgraph` to `error_on_graph_break`. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph=True` is guaranteed to produce exactly one graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now. I'm hoping that won't be necessary, since it is still a private API: there are no internal call sites yet and no significant OSS call sites yet.

cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users of `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

435 lines
15 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py
index 1b9f3cf7624..6560c7423a6 100644
--- a/test/dynamo/cpython/3_13/test_iter.py
+++ b/test/dynamo/cpython/3_13/test_iter.py
@@ -1,3 +1,60 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_iter.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+ skipIfTorchDynamo,
+ run_tests,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
# Test iterators.
import sys
@@ -104,12 +161,10 @@ class EmptyIterClass:
# Main test suite
-class TestCase(unittest.TestCase):
+class TestCase(__TestCase):
# Helper to check that an iterator returns a given sequence
def check_iterator(self, it, seq, pickle=True):
- if pickle:
- self.check_pickle(it, seq)
res = []
while 1:
try:
@@ -121,8 +176,6 @@ class TestCase(unittest.TestCase):
# Helper to check that a for loop generates a given sequence
def check_for_loop(self, expr, seq, pickle=True):
- if pickle:
- self.check_pickle(iter(expr), seq)
res = []
for val in expr:
res.append(val)
@@ -261,19 +314,20 @@ class TestCase(unittest.TestCase):
def run(builtin_name, item, sentinel=None):
it = iter(item) if sentinel is None else iter(item, sentinel)
- class CustomStr:
- def __init__(self, name, iterator):
- self.name = name
- self.iterator = iterator
- def __hash__(self):
- return hash(self.name)
- def __eq__(self, other):
- # Here we exhaust our iterator, possibly changing
- # its `it_seq` pointer to NULL
- # The `__reduce__` call should correctly get
- # the pointers after this call
- list(self.iterator)
- return other == self.name
+ with torch._dynamo.error_on_graph_break(False):
+ class CustomStr:
+ def __init__(self, name, iterator):
+ self.name = name
+ self.iterator = iterator
+ def __hash__(self):
+ return hash(self.name)
+ def __eq__(self, other):
+ # Here we exhaust our iterator, possibly changing
+ # its `it_seq` pointer to NULL
+ # The `__reduce__` call should correctly get
+ # the pointers after this call
+ list(self.iterator)
+ return other == self.name
# del is required here
# to not prematurely call __eq__ from
@@ -323,9 +377,10 @@ class TestCase(unittest.TestCase):
# Test a new_style class with __iter__ but no next() method
def test_new_style_iter_class(self):
- class IterClass(object):
- def __iter__(self):
- return self
+ with torch._dynamo.error_on_graph_break(False):
+ class IterClass(object):
+ def __iter__(self):
+ return self
self.assertRaises(TypeError, iter, IterClass())
# Test two-argument iter() with callable instance
@@ -394,11 +449,12 @@ class TestCase(unittest.TestCase):
# Test exception propagation through sequence iterator
def test_exception_sequence(self):
- class MySequenceClass(SequenceClass):
- def __getitem__(self, i):
- if i == 10:
- raise RuntimeError
- return SequenceClass.__getitem__(self, i)
+ with torch._dynamo.error_on_graph_break(False):
+ class MySequenceClass(SequenceClass):
+ def __getitem__(self, i):
+ if i == 10:
+ raise RuntimeError
+ return SequenceClass.__getitem__(self, i)
res = []
try:
for x in MySequenceClass(20):
@@ -410,11 +466,12 @@ class TestCase(unittest.TestCase):
# Test for StopIteration from __getitem__
def test_stop_sequence(self):
- class MySequenceClass(SequenceClass):
- def __getitem__(self, i):
- if i == 10:
- raise StopIteration
- return SequenceClass.__getitem__(self, i)
+ with torch._dynamo.error_on_graph_break(False):
+ class MySequenceClass(SequenceClass):
+ def __getitem__(self, i):
+ if i == 10:
+ raise StopIteration
+ return SequenceClass.__getitem__(self, i)
self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
# Test a big range
@@ -541,32 +598,34 @@ class TestCase(unittest.TestCase):
self.assertRaises(TypeError, filter, None, list)
self.assertRaises(TypeError, filter, None, 42)
- class Boolean:
- def __init__(self, truth):
- self.truth = truth
- def __bool__(self):
- return self.truth
+ with torch._dynamo.error_on_graph_break(False):
+ class Boolean:
+ def __init__(self, truth):
+ self.truth = truth
+ def __bool__(self):
+ return self.truth
bTrue = Boolean(True)
bFalse = Boolean(False)
- class Seq:
- def __init__(self, *args):
- self.vals = args
- def __iter__(self):
- class SeqIter:
- def __init__(self, vals):
- self.vals = vals
- self.i = 0
- def __iter__(self):
- return self
- def __next__(self):
- i = self.i
- self.i = i + 1
- if i < len(self.vals):
- return self.vals[i]
- else:
- raise StopIteration
- return SeqIter(self.vals)
+ with torch._dynamo.error_on_graph_break(False):
+ class Seq:
+ def __init__(self, *args):
+ self.vals = args
+ def __iter__(self):
+ class SeqIter:
+ def __init__(self, vals):
+ self.vals = vals
+ self.i = 0
+ def __iter__(self):
+ return self
+ def __next__(self):
+ i = self.i
+ self.i = i + 1
+ if i < len(self.vals):
+ return self.vals[i]
+ else:
+ raise StopIteration
+ return SeqIter(self.vals)
seq = Seq(*([bTrue, bFalse] * 25))
self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
@@ -635,6 +694,7 @@ class TestCase(unittest.TestCase):
pass
# Test zip()'s use of iterators.
+ @skipIfTorchDynamo("infinite loop")
def test_builtin_zip(self):
self.assertEqual(list(zip()), [])
self.assertEqual(list(zip(*[])), [])
@@ -653,17 +713,18 @@ class TestCase(unittest.TestCase):
self.assertEqual(list(d.items()), list(zip(d, d.values())))
# Generate all ints starting at constructor arg.
- class IntsFrom:
- def __init__(self, start):
- self.i = start
+ with torch._dynamo.error_on_graph_break(False):
+ class IntsFrom:
+ def __init__(self, start):
+ self.i = start
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- def __next__(self):
- i = self.i
- self.i = i+1
- return i
+ def __next__(self):
+ i = self.i
+ self.i = i+1
+ return i
f = open(TESTFN, "w", encoding="utf-8")
try:
@@ -686,19 +747,20 @@ class TestCase(unittest.TestCase):
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
# Classes that lie about their lengths.
- class NoGuessLen5:
- def __getitem__(self, i):
- if i >= 5:
- raise IndexError
- return i
+ with torch._dynamo.error_on_graph_break(False):
+ class NoGuessLen5:
+ def __getitem__(self, i):
+ if i >= 5:
+ raise IndexError
+ return i
- class Guess3Len5(NoGuessLen5):
- def __len__(self):
- return 3
+ class Guess3Len5(NoGuessLen5):
+ def __len__(self):
+ return 3
- class Guess30Len5(NoGuessLen5):
- def __len__(self):
- return 30
+ class Guess30Len5(NoGuessLen5):
+ def __len__(self):
+ return 30
def lzip(*args):
return list(zip(*args))
@@ -718,20 +780,21 @@ class TestCase(unittest.TestCase):
# This class inserts a Unicode object into its argument's natural
# iteration, in the 3rd position.
- class OhPhooey:
- def __init__(self, seq):
- self.it = iter(seq)
- self.i = 0
+ with torch._dynamo.error_on_graph_break(False):
+ class OhPhooey:
+ def __init__(self, seq):
+ self.it = iter(seq)
+ self.i = 0
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- def __next__(self):
- i = self.i
- self.i = i+1
- if i == 2:
- return "fooled you!"
- return next(self.it)
+ def __next__(self):
+ i = self.i
+ self.i = i+1
+ if i == 2:
+ return "fooled you!"
+ return next(self.it)
f = open(TESTFN, "w", encoding="utf-8")
try:
@@ -895,29 +958,30 @@ class TestCase(unittest.TestCase):
f.writelines({})
# Try a big chunk too.
- class Iterator:
- def __init__(self, start, finish):
- self.start = start
- self.finish = finish
- self.i = self.start
+ with torch._dynamo.error_on_graph_break(False):
+ class Iterator:
+ def __init__(self, start, finish):
+ self.start = start
+ self.finish = finish
+ self.i = self.start
- def __next__(self):
- if self.i >= self.finish:
- raise StopIteration
- result = str(self.i) + '\n'
- self.i += 1
- return result
+ def __next__(self):
+ if self.i >= self.finish:
+ raise StopIteration
+ result = str(self.i) + '\n'
+ self.i += 1
+ return result
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- class Whatever:
- def __init__(self, start, finish):
- self.start = start
- self.finish = finish
+ class Whatever:
+ def __init__(self, start, finish):
+ self.start = start
+ self.finish = finish
- def __iter__(self):
- return Iterator(self.start, self.finish)
+ def __iter__(self):
+ return Iterator(self.start, self.finish)
f.writelines(Whatever(6, 6+2000))
f.close()
@@ -990,15 +1054,16 @@ class TestCase(unittest.TestCase):
@cpython_only
def test_ref_counting_behavior(self):
- class C(object):
- count = 0
- def __new__(cls):
- cls.count += 1
- return object.__new__(cls)
- def __del__(self):
- cls = self.__class__
- assert cls.count > 0
- cls.count -= 1
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ count = 0
+ def __new__(cls):
+ cls.count += 1
+ return object.__new__(cls)
+ def __del__(self):
+ cls = self.__class__
+ assert cls.count > 0
+ cls.count -= 1
x = C()
self.assertEqual(C.count, 1)
del x
@@ -1089,12 +1154,13 @@ class TestCase(unittest.TestCase):
def test_3720(self):
# Avoid a crash, when an iterator deletes its next() method.
- class BadIterator(object):
- def __iter__(self):
- return self
- def __next__(self):
- del BadIterator.__next__
- return 1
+ with torch._dynamo.error_on_graph_break(False):
+ class BadIterator(object):
+ def __iter__(self):
+ return self
+ def __next__(self):
+ del BadIterator.__next__
+ return 1
try:
for i in BadIterator() :
@@ -1187,4 +1253,4 @@ class TestCase(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()