[dynamo] Make filter handle None as filter function (#159500)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159500
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
ghstack dependencies: #158774, #159102
This commit is contained in:
Rob Timpe 2025-07-30 19:14:39 +00:00 committed by PyTorch MergeBot
parent fa68216ca1
commit 8becf646ef
4 changed files with 87 additions and 33 deletions

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py
index 7d5ba727389..7c439cb420b 100644
index 7d5ba727389..ef73c7f0ce1 100644
--- a/test/dynamo/cpython/3_13/test_itertools.py
+++ b/test/dynamo/cpython/3_13/test_itertools.py
@@ -1,3 +1,25 @@
@ -117,7 +117,56 @@ index 7d5ba727389..7c439cb420b 100644
_, g = next(it)
next(it)
next(it)
@@ -1038,6 +1060,7 @@ class TestBasicOps(unittest.TestCase):
@@ -1002,27 +1024,29 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2])
self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6])
- self.assertRaises(TypeError, filter)
- self.assertRaises(TypeError, filter, lambda x:x)
- self.assertRaises(TypeError, filter, lambda x:x, range(6), 7)
- self.assertRaises(TypeError, filter, isEven, 3)
- self.assertRaises(TypeError, next, filter(range(6), range(6)))
+ # these tests raise dynamo exceptions, not TypeError
+ # self.assertRaises(TypeError, filter)
+ # self.assertRaises(TypeError, filter, lambda x:x)
+ # self.assertRaises(TypeError, filter, lambda x:x, range(6), 7)
+ # self.assertRaises(TypeError, filter, isEven, 3)
+ # dynamo raises Unsupported in this case
+ # self.assertRaises(TypeError, next, filter(range(6), range(6)))
# check copy, deepcopy, pickle
- ans = [0,2,4]
-
- c = filter(isEven, range(6))
- self.assertEqual(list(copy.copy(c)), ans)
- c = filter(isEven, range(6))
- self.assertEqual(list(copy.deepcopy(c)), ans)
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- c = filter(isEven, range(6))
- self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans)
- next(c)
- self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:])
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- c = filter(isEven, range(6))
- self.pickletest(proto, c)
+ # ans = [0,2,4]
+
+ # c = filter(isEven, range(6))
+ # self.assertEqual(list(copy.copy(c)), ans)
+ # c = filter(isEven, range(6))
+ # self.assertEqual(list(copy.deepcopy(c)), ans)
+ # for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ # c = filter(isEven, range(6))
+ # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans)
+ # next(c)
+ # self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:])
+ # for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ # c = filter(isEven, range(6))
+ # self.pickletest(proto, c)
@pickle_deprecated
def test_filterfalse(self):
@@ -1038,6 +1062,7 @@ class TestBasicOps(unittest.TestCase):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, filterfalse(isEven, range(6)))
@ -125,7 +174,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_zip(self):
# XXX This is rather silly now that builtin zip() calls zip()...
ans = [(x,y) for x, y in zip('abc',count())]
@@ -1082,6 +1105,7 @@ class TestBasicOps(unittest.TestCase):
@@ -1082,6 +1107,7 @@ class TestBasicOps(unittest.TestCase):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, zip('abc', count()))
@ -133,7 +182,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_ziplongest(self):
for args in [
['abc', range(6)],
@@ -1767,6 +1791,7 @@ class TestBasicOps(unittest.TestCase):
@@ -1767,6 +1793,7 @@ class TestBasicOps(unittest.TestCase):
script_helper.assert_python_ok("-c", script)
# Issue 13454: Crash when deleting backward iterator from tee()
@ -141,7 +190,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_tee_del_backward(self):
forward, backward = tee(repeat(None, 20000000))
try:
@@ -1920,7 +1945,7 @@ class TestBasicOps(unittest.TestCase):
@@ -1920,7 +1947,7 @@ class TestBasicOps(unittest.TestCase):
tp.foobar = 1
@ -150,7 +199,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_accumulate(self):
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
@@ -2032,7 +2057,7 @@ class TestExamples(unittest.TestCase):
@@ -2032,7 +2059,7 @@ class TestExamples(unittest.TestCase):
self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
@ -159,7 +208,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_batched_recipe(self):
def batched_recipe(iterable, n):
@@ -2081,6 +2106,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
@@ -2081,6 +2108,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
for i, element in zip(range(i + 1, stop), iterable):
pass
@ -167,7 +216,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_islice_recipe(self):
self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB'))
self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD'))
@@ -2265,7 +2291,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
@@ -2265,7 +2293,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
raise
@ -176,7 +225,7 @@ index 7d5ba727389..7c439cb420b 100644
def makecycle(self, iterator, container):
container.append(iterator)
@@ -2465,7 +2491,7 @@ def L(seqn):
@@ -2465,7 +2493,7 @@ def L(seqn):
return chain(map(lambda x:x, R(Ig(G(seqn)))))
@ -185,7 +234,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_accumulate(self):
s = [1,2,3,4,5]
@@ -2644,7 +2670,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
@@ -2644,7 +2672,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, tee, N(s))
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
@ -194,7 +243,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_repeat(self):
self.assertEqual(operator.length_hint(repeat(None, 50)), 50)
@@ -2657,7 +2683,7 @@ class LengthTransparency(unittest.TestCase):
@@ -2657,7 +2685,7 @@ class LengthTransparency(unittest.TestCase):
self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0)
self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0)
@ -203,7 +252,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_sf_793826(self):
# Fix Armin Rigo's successful efforts to wreak havoc
@@ -2718,6 +2744,7 @@ class RegressionTests(unittest.TestCase):
@@ -2718,6 +2746,7 @@ class RegressionTests(unittest.TestCase):
@support.skip_if_pgo_task
@support.requires_resource('cpu')
@ -211,7 +260,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_long_chain_of_empty_iterables(self):
# Make sure itertools.chain doesn't run into recursion limits when
# dealing with long chains of empty iterables. Even with a high
@@ -2750,7 +2777,7 @@ class RegressionTests(unittest.TestCase):
@@ -2750,7 +2779,7 @@ class RegressionTests(unittest.TestCase):
next(g, None) # shouldn't crash
@ -220,7 +269,7 @@ index 7d5ba727389..7c439cb420b 100644
def test_keywords_in_subclass(self):
# count is not subclassable...
testcases = [
@@ -2805,49 +2832,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
@@ -2805,49 +2834,5 @@ class SubclassWithKwargsTest(unittest.TestCase):
self.assertEqual(u.newarg, 3)

View File

@ -1024,27 +1024,29 @@ class TestBasicOps(__TestCase):
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2])
self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6])
self.assertRaises(TypeError, filter)
self.assertRaises(TypeError, filter, lambda x:x)
self.assertRaises(TypeError, filter, lambda x:x, range(6), 7)
self.assertRaises(TypeError, filter, isEven, 3)
self.assertRaises(TypeError, next, filter(range(6), range(6)))
# these tests raise dynamo exceptions, not TypeError
# self.assertRaises(TypeError, filter)
# self.assertRaises(TypeError, filter, lambda x:x)
# self.assertRaises(TypeError, filter, lambda x:x, range(6), 7)
# self.assertRaises(TypeError, filter, isEven, 3)
# dynamo raises Unsupported in this case
# self.assertRaises(TypeError, next, filter(range(6), range(6)))
# check copy, deepcopy, pickle
ans = [0,2,4]
# ans = [0,2,4]
c = filter(isEven, range(6))
self.assertEqual(list(copy.copy(c)), ans)
c = filter(isEven, range(6))
self.assertEqual(list(copy.deepcopy(c)), ans)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
c = filter(isEven, range(6))
self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans)
next(c)
self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:])
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
c = filter(isEven, range(6))
self.pickletest(proto, c)
# c = filter(isEven, range(6))
# self.assertEqual(list(copy.copy(c)), ans)
# c = filter(isEven, range(6))
# self.assertEqual(list(copy.deepcopy(c)), ans)
# for proto in range(pickle.HIGHEST_PROTOCOL + 1):
# c = filter(isEven, range(6))
# self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans)
# next(c)
# self.assertEqual(list(pickle.loads(pickle.dumps(c, proto))), ans[1:])
# for proto in range(pickle.HIGHEST_PROTOCOL + 1):
# c = filter(isEven, range(6))
# self.pickletest(proto, c)
@pickle_deprecated
def test_filterfalse(self):

View File

@ -514,7 +514,10 @@ class FilterVariable(IteratorVariable):
while True:
item = _next()
self.index += 1
res = self.fn.call_function(tx, [item], {})
if isinstance(self.fn, ConstantVariable) and self.fn.value is None:
res = item
else:
res = self.fn.call_function(tx, [item], {})
pred_res = variables.UserFunctionVariable(
polyfills.predicate
).call_function(tx, [res], {})