mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Rename `set_fullgraph` to `error_on_graph_break`; for now, there are no semantic differences. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` truly guarantees a single graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary, since it is still a private API (there are no internal call sites yet, and no significant OSS call sites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users of `set_fullgraph`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739 Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
435 lines
15 KiB
Diff
435 lines
15 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py
|
|
index 1b9f3cf7624..6560c7423a6 100644
|
|
--- a/test/dynamo/cpython/3_13/test_iter.py
|
|
+++ b/test/dynamo/cpython/3_13/test_iter.py
|
|
@@ -1,3 +1,60 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_iter.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import (
|
|
+ skipIfTorchDynamo,
|
|
+ run_tests,
|
|
+)
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
# Test iterators.
|
|
|
|
import sys
|
|
@@ -104,12 +161,10 @@ class EmptyIterClass:
|
|
|
|
# Main test suite
|
|
|
|
-class TestCase(unittest.TestCase):
|
|
+class TestCase(__TestCase):
|
|
|
|
# Helper to check that an iterator returns a given sequence
|
|
def check_iterator(self, it, seq, pickle=True):
|
|
- if pickle:
|
|
- self.check_pickle(it, seq)
|
|
res = []
|
|
while 1:
|
|
try:
|
|
@@ -121,8 +176,6 @@ class TestCase(unittest.TestCase):
|
|
|
|
# Helper to check that a for loop generates a given sequence
|
|
def check_for_loop(self, expr, seq, pickle=True):
|
|
- if pickle:
|
|
- self.check_pickle(iter(expr), seq)
|
|
res = []
|
|
for val in expr:
|
|
res.append(val)
|
|
@@ -261,19 +314,20 @@ class TestCase(unittest.TestCase):
|
|
def run(builtin_name, item, sentinel=None):
|
|
it = iter(item) if sentinel is None else iter(item, sentinel)
|
|
|
|
- class CustomStr:
|
|
- def __init__(self, name, iterator):
|
|
- self.name = name
|
|
- self.iterator = iterator
|
|
- def __hash__(self):
|
|
- return hash(self.name)
|
|
- def __eq__(self, other):
|
|
- # Here we exhaust our iterator, possibly changing
|
|
- # its `it_seq` pointer to NULL
|
|
- # The `__reduce__` call should correctly get
|
|
- # the pointers after this call
|
|
- list(self.iterator)
|
|
- return other == self.name
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class CustomStr:
|
|
+ def __init__(self, name, iterator):
|
|
+ self.name = name
|
|
+ self.iterator = iterator
|
|
+ def __hash__(self):
|
|
+ return hash(self.name)
|
|
+ def __eq__(self, other):
|
|
+ # Here we exhaust our iterator, possibly changing
|
|
+ # its `it_seq` pointer to NULL
|
|
+ # The `__reduce__` call should correctly get
|
|
+ # the pointers after this call
|
|
+ list(self.iterator)
|
|
+ return other == self.name
|
|
|
|
# del is required here
|
|
# to not prematurely call __eq__ from
|
|
@@ -323,9 +377,10 @@ class TestCase(unittest.TestCase):
|
|
|
|
# Test a new_style class with __iter__ but no next() method
|
|
def test_new_style_iter_class(self):
|
|
- class IterClass(object):
|
|
- def __iter__(self):
|
|
- return self
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class IterClass(object):
|
|
+ def __iter__(self):
|
|
+ return self
|
|
self.assertRaises(TypeError, iter, IterClass())
|
|
|
|
# Test two-argument iter() with callable instance
|
|
@@ -394,11 +449,12 @@ class TestCase(unittest.TestCase):
|
|
|
|
# Test exception propagation through sequence iterator
|
|
def test_exception_sequence(self):
|
|
- class MySequenceClass(SequenceClass):
|
|
- def __getitem__(self, i):
|
|
- if i == 10:
|
|
- raise RuntimeError
|
|
- return SequenceClass.__getitem__(self, i)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MySequenceClass(SequenceClass):
|
|
+ def __getitem__(self, i):
|
|
+ if i == 10:
|
|
+ raise RuntimeError
|
|
+ return SequenceClass.__getitem__(self, i)
|
|
res = []
|
|
try:
|
|
for x in MySequenceClass(20):
|
|
@@ -410,11 +466,12 @@ class TestCase(unittest.TestCase):
|
|
|
|
# Test for StopIteration from __getitem__
|
|
def test_stop_sequence(self):
|
|
- class MySequenceClass(SequenceClass):
|
|
- def __getitem__(self, i):
|
|
- if i == 10:
|
|
- raise StopIteration
|
|
- return SequenceClass.__getitem__(self, i)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MySequenceClass(SequenceClass):
|
|
+ def __getitem__(self, i):
|
|
+ if i == 10:
|
|
+ raise StopIteration
|
|
+ return SequenceClass.__getitem__(self, i)
|
|
self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
|
|
|
|
# Test a big range
|
|
@@ -541,32 +598,34 @@ class TestCase(unittest.TestCase):
|
|
self.assertRaises(TypeError, filter, None, list)
|
|
self.assertRaises(TypeError, filter, None, 42)
|
|
|
|
- class Boolean:
|
|
- def __init__(self, truth):
|
|
- self.truth = truth
|
|
- def __bool__(self):
|
|
- return self.truth
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Boolean:
|
|
+ def __init__(self, truth):
|
|
+ self.truth = truth
|
|
+ def __bool__(self):
|
|
+ return self.truth
|
|
bTrue = Boolean(True)
|
|
bFalse = Boolean(False)
|
|
|
|
- class Seq:
|
|
- def __init__(self, *args):
|
|
- self.vals = args
|
|
- def __iter__(self):
|
|
- class SeqIter:
|
|
- def __init__(self, vals):
|
|
- self.vals = vals
|
|
- self.i = 0
|
|
- def __iter__(self):
|
|
- return self
|
|
- def __next__(self):
|
|
- i = self.i
|
|
- self.i = i + 1
|
|
- if i < len(self.vals):
|
|
- return self.vals[i]
|
|
- else:
|
|
- raise StopIteration
|
|
- return SeqIter(self.vals)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Seq:
|
|
+ def __init__(self, *args):
|
|
+ self.vals = args
|
|
+ def __iter__(self):
|
|
+ class SeqIter:
|
|
+ def __init__(self, vals):
|
|
+ self.vals = vals
|
|
+ self.i = 0
|
|
+ def __iter__(self):
|
|
+ return self
|
|
+ def __next__(self):
|
|
+ i = self.i
|
|
+ self.i = i + 1
|
|
+ if i < len(self.vals):
|
|
+ return self.vals[i]
|
|
+ else:
|
|
+ raise StopIteration
|
|
+ return SeqIter(self.vals)
|
|
|
|
seq = Seq(*([bTrue, bFalse] * 25))
|
|
self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
|
|
@@ -635,6 +694,7 @@ class TestCase(unittest.TestCase):
|
|
pass
|
|
|
|
# Test zip()'s use of iterators.
|
|
+ @skipIfTorchDynamo("infinite loop")
|
|
def test_builtin_zip(self):
|
|
self.assertEqual(list(zip()), [])
|
|
self.assertEqual(list(zip(*[])), [])
|
|
@@ -653,17 +713,18 @@ class TestCase(unittest.TestCase):
|
|
self.assertEqual(list(d.items()), list(zip(d, d.values())))
|
|
|
|
# Generate all ints starting at constructor arg.
|
|
- class IntsFrom:
|
|
- def __init__(self, start):
|
|
- self.i = start
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class IntsFrom:
|
|
+ def __init__(self, start):
|
|
+ self.i = start
|
|
|
|
- def __iter__(self):
|
|
- return self
|
|
+ def __iter__(self):
|
|
+ return self
|
|
|
|
- def __next__(self):
|
|
- i = self.i
|
|
- self.i = i+1
|
|
- return i
|
|
+ def __next__(self):
|
|
+ i = self.i
|
|
+ self.i = i+1
|
|
+ return i
|
|
|
|
f = open(TESTFN, "w", encoding="utf-8")
|
|
try:
|
|
@@ -686,19 +747,20 @@ class TestCase(unittest.TestCase):
|
|
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
|
|
|
|
# Classes that lie about their lengths.
|
|
- class NoGuessLen5:
|
|
- def __getitem__(self, i):
|
|
- if i >= 5:
|
|
- raise IndexError
|
|
- return i
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class NoGuessLen5:
|
|
+ def __getitem__(self, i):
|
|
+ if i >= 5:
|
|
+ raise IndexError
|
|
+ return i
|
|
|
|
- class Guess3Len5(NoGuessLen5):
|
|
- def __len__(self):
|
|
- return 3
|
|
+ class Guess3Len5(NoGuessLen5):
|
|
+ def __len__(self):
|
|
+ return 3
|
|
|
|
- class Guess30Len5(NoGuessLen5):
|
|
- def __len__(self):
|
|
- return 30
|
|
+ class Guess30Len5(NoGuessLen5):
|
|
+ def __len__(self):
|
|
+ return 30
|
|
|
|
def lzip(*args):
|
|
return list(zip(*args))
|
|
@@ -718,20 +780,21 @@ class TestCase(unittest.TestCase):
|
|
|
|
# This class inserts a Unicode object into its argument's natural
|
|
# iteration, in the 3rd position.
|
|
- class OhPhooey:
|
|
- def __init__(self, seq):
|
|
- self.it = iter(seq)
|
|
- self.i = 0
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class OhPhooey:
|
|
+ def __init__(self, seq):
|
|
+ self.it = iter(seq)
|
|
+ self.i = 0
|
|
|
|
- def __iter__(self):
|
|
- return self
|
|
+ def __iter__(self):
|
|
+ return self
|
|
|
|
- def __next__(self):
|
|
- i = self.i
|
|
- self.i = i+1
|
|
- if i == 2:
|
|
- return "fooled you!"
|
|
- return next(self.it)
|
|
+ def __next__(self):
|
|
+ i = self.i
|
|
+ self.i = i+1
|
|
+ if i == 2:
|
|
+ return "fooled you!"
|
|
+ return next(self.it)
|
|
|
|
f = open(TESTFN, "w", encoding="utf-8")
|
|
try:
|
|
@@ -895,29 +958,30 @@ class TestCase(unittest.TestCase):
|
|
f.writelines({})
|
|
|
|
# Try a big chunk too.
|
|
- class Iterator:
|
|
- def __init__(self, start, finish):
|
|
- self.start = start
|
|
- self.finish = finish
|
|
- self.i = self.start
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Iterator:
|
|
+ def __init__(self, start, finish):
|
|
+ self.start = start
|
|
+ self.finish = finish
|
|
+ self.i = self.start
|
|
|
|
- def __next__(self):
|
|
- if self.i >= self.finish:
|
|
- raise StopIteration
|
|
- result = str(self.i) + '\n'
|
|
- self.i += 1
|
|
- return result
|
|
+ def __next__(self):
|
|
+ if self.i >= self.finish:
|
|
+ raise StopIteration
|
|
+ result = str(self.i) + '\n'
|
|
+ self.i += 1
|
|
+ return result
|
|
|
|
- def __iter__(self):
|
|
- return self
|
|
+ def __iter__(self):
|
|
+ return self
|
|
|
|
- class Whatever:
|
|
- def __init__(self, start, finish):
|
|
- self.start = start
|
|
- self.finish = finish
|
|
+ class Whatever:
|
|
+ def __init__(self, start, finish):
|
|
+ self.start = start
|
|
+ self.finish = finish
|
|
|
|
- def __iter__(self):
|
|
- return Iterator(self.start, self.finish)
|
|
+ def __iter__(self):
|
|
+ return Iterator(self.start, self.finish)
|
|
|
|
f.writelines(Whatever(6, 6+2000))
|
|
f.close()
|
|
@@ -990,15 +1054,16 @@ class TestCase(unittest.TestCase):
|
|
|
|
@cpython_only
|
|
def test_ref_counting_behavior(self):
|
|
- class C(object):
|
|
- count = 0
|
|
- def __new__(cls):
|
|
- cls.count += 1
|
|
- return object.__new__(cls)
|
|
- def __del__(self):
|
|
- cls = self.__class__
|
|
- assert cls.count > 0
|
|
- cls.count -= 1
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class C(object):
|
|
+ count = 0
|
|
+ def __new__(cls):
|
|
+ cls.count += 1
|
|
+ return object.__new__(cls)
|
|
+ def __del__(self):
|
|
+ cls = self.__class__
|
|
+ assert cls.count > 0
|
|
+ cls.count -= 1
|
|
x = C()
|
|
self.assertEqual(C.count, 1)
|
|
del x
|
|
@@ -1089,12 +1154,13 @@ class TestCase(unittest.TestCase):
|
|
|
|
def test_3720(self):
|
|
# Avoid a crash, when an iterator deletes its next() method.
|
|
- class BadIterator(object):
|
|
- def __iter__(self):
|
|
- return self
|
|
- def __next__(self):
|
|
- del BadIterator.__next__
|
|
- return 1
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class BadIterator(object):
|
|
+ def __iter__(self):
|
|
+ return self
|
|
+ def __next__(self):
|
|
+ del BadIterator.__next__
|
|
+ return 1
|
|
|
|
try:
|
|
for i in BadIterator() :
|
|
@@ -1187,4 +1253,4 @@ class TestCase(unittest.TestCase):
|
|
|
|
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|