mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Support itertools.filterfalse (#160596)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160596 Approved by: https://github.com/guilhermeleobas
This commit is contained in:
parent
450517f346
commit
162bf78df6
|
|
@ -1,5 +1,5 @@
|
|||
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
|
||||
index 7d5ba727389..d15d83a2184 100644
|
||||
index 7d5ba727389..8d462284884 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_itertools.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_itertools.py
|
||||
@@ -1,3 +1,25 @@
|
||||
|
|
@ -151,7 +151,7 @@ index 7d5ba727389..d15d83a2184 100644
|
|||
_, g = next(it)
|
||||
next(it)
|
||||
next(it)
|
||||
@@ -1002,27 +1015,29 @@ class TestBasicOps(unittest.TestCase):
|
||||
@@ -1002,29 +1015,30 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
|
||||
self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2])
|
||||
self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6])
|
||||
|
|
@ -198,8 +198,24 @@ index 7d5ba727389..d15d83a2184 100644
|
|||
+ # c = filter(isEven, range(6))
|
||||
+ # self.pickletest(proto, c)
|
||||
|
||||
@pickle_deprecated
|
||||
- @pickle_deprecated
|
||||
def test_filterfalse(self):
|
||||
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
|
||||
self.assertEqual(list(filterfalse(None, [0,1,0,2,0])), [0,0,0])
|
||||
@@ -1034,9 +1048,10 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertRaises(TypeError, filterfalse, lambda x:x)
|
||||
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
|
||||
self.assertRaises(TypeError, filterfalse, isEven, 3)
|
||||
- self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
||||
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
- self.pickletest(proto, filterfalse(isEven, range(6)))
|
||||
+ with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
+ self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
||||
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
+ self.pickletest(proto, filterfalse(isEven, range(6)))
|
||||
|
||||
def test_zip(self):
|
||||
# XXX This is rather silly now that builtin zip() calls zip()...
|
||||
@@ -1047,8 +1062,8 @@ class TestBasicOps(unittest.TestCase):
|
||||
self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3)))
|
||||
self.assertEqual(list(zip('abcdef')), lzip('abcdef'))
|
||||
|
|
|
|||
|
|
@ -1039,7 +1039,6 @@ class TestBasicOps(__TestCase):
|
|||
# c = filter(isEven, range(6))
|
||||
# self.pickletest(proto, c)
|
||||
|
||||
@pickle_deprecated
|
||||
def test_filterfalse(self):
|
||||
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
|
||||
self.assertEqual(list(filterfalse(None, [0,1,0,2,0])), [0,0,0])
|
||||
|
|
@ -1049,9 +1048,10 @@ class TestBasicOps(__TestCase):
|
|||
self.assertRaises(TypeError, filterfalse, lambda x:x)
|
||||
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
|
||||
self.assertRaises(TypeError, filterfalse, isEven, 3)
|
||||
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
self.pickletest(proto, filterfalse(isEven, range(6)))
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
self.pickletest(proto, filterfalse(isEven, range(6)))
|
||||
|
||||
def test_zip(self):
|
||||
# XXX This is rather silly now that builtin zip() calls zip()...
|
||||
|
|
|
|||
|
|
@ -310,6 +310,12 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
itertools.permutations(filter(lambda x: True, [1, 2]))
|
||||
return a
|
||||
|
||||
@make_test
|
||||
def test_itertools_filterfalse_basic(a, b):
|
||||
for x in itertools.filterfalse(lambda x: x > 0, [-0.5, 0, 0.5]):
|
||||
a += x
|
||||
return a
|
||||
|
||||
@make_test
|
||||
def test_itertools_chain(a, b):
|
||||
v = a
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ __all__ = [
|
|||
"compress",
|
||||
"cycle",
|
||||
"dropwhile",
|
||||
"filterfalse",
|
||||
"islice",
|
||||
"tee",
|
||||
"zip_longest",
|
||||
|
|
@ -123,6 +124,15 @@ def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[
|
|||
yield from iterator
|
||||
|
||||
|
||||
@substitute_in_graph(itertools.filterfalse, is_embedded_type=True) # type: ignore[arg-type]
|
||||
def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
|
||||
it = iter(iterable)
|
||||
if function is None:
|
||||
return filter(operator.not_, it)
|
||||
else:
|
||||
return filter(lambda x: not function(x), it)
|
||||
|
||||
|
||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
|
||||
@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
|
||||
def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user