[dynamo] Support itertools.filterfalse (#160596)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160596
Approved by: https://github.com/guilhermeleobas
This commit is contained in:
Rob Timpe 2025-08-15 21:14:20 +00:00 committed by PyTorch MergeBot
parent 450517f346
commit 162bf78df6
6 changed files with 39 additions and 7 deletions

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
index 7d5ba727389..d15d83a2184 100644
index 7d5ba727389..8d462284884 100644
--- a/test/dynamo/cpython/3_13/test_itertools.py
+++ b/test/dynamo/cpython/3_13/test_itertools.py
@@ -1,3 +1,25 @@
@ -151,7 +151,7 @@ index 7d5ba727389..d15d83a2184 100644
_, g = next(it)
next(it)
next(it)
@@ -1002,27 +1015,29 @@ class TestBasicOps(unittest.TestCase):
@@ -1002,29 +1015,30 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2])
self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6])
@ -198,8 +198,24 @@ index 7d5ba727389..d15d83a2184 100644
+ # c = filter(isEven, range(6))
+ # self.pickletest(proto, c)
@pickle_deprecated
- @pickle_deprecated
def test_filterfalse(self):
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
self.assertEqual(list(filterfalse(None, [0,1,0,2,0])), [0,0,0])
@@ -1034,9 +1048,10 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, filterfalse, lambda x:x)
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
self.assertRaises(TypeError, filterfalse, isEven, 3)
- self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- self.pickletest(proto, filterfalse(isEven, range(6)))
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ self.pickletest(proto, filterfalse(isEven, range(6)))
def test_zip(self):
# XXX This is rather silly now that builtin zip() calls zip()...
@@ -1047,8 +1062,8 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3)))
self.assertEqual(list(zip('abcdef')), lzip('abcdef'))

View File

@ -1039,7 +1039,6 @@ class TestBasicOps(__TestCase):
# c = filter(isEven, range(6))
# self.pickletest(proto, c)
@pickle_deprecated
def test_filterfalse(self):
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
self.assertEqual(list(filterfalse(None, [0,1,0,2,0])), [0,0,0])
@ -1049,9 +1048,10 @@ class TestBasicOps(__TestCase):
self.assertRaises(TypeError, filterfalse, lambda x:x)
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
self.assertRaises(TypeError, filterfalse, isEven, 3)
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, filterfalse(isEven, range(6)))
with torch._dynamo.set_fullgraph(fullgraph=False):
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, filterfalse(isEven, range(6)))
def test_zip(self):
# XXX This is rather silly now that builtin zip() calls zip()...

View File

@ -310,6 +310,12 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
itertools.permutations(filter(lambda x: True, [1, 2]))
return a
@make_test
def test_itertools_filterfalse_basic(a, b):
for x in itertools.filterfalse(lambda x: x > 0, [-0.5, 0, 0.5]):
a += x
return a
@make_test
def test_itertools_chain(a, b):
v = a

View File

@ -24,6 +24,7 @@ __all__ = [
"compress",
"cycle",
"dropwhile",
"filterfalse",
"islice",
"tee",
"zip_longest",
@ -123,6 +124,15 @@ def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[
yield from iterator
@substitute_in_graph(itertools.filterfalse, is_embedded_type=True) # type: ignore[arg-type]
def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
it = iter(iterable)
if function is None:
return filter(operator.not_, it)
else:
return filter(lambda x: not function(x), it)
# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: