mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
redirect iter(range) to range.__iter__() (#161803)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161803 Approved by: https://github.com/anijain2305 ghstack dependencies: #161801, #161802
This commit is contained in:
parent
485a7bd82e
commit
c8255c67cd
|
|
@ -3529,7 +3529,6 @@ class GraphModule(torch.nn.Module):
|
||||||
return a + b
|
return a + b
|
||||||
return a - b
|
return a - b
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
@make_test
|
@make_test
|
||||||
def test_range_iterator_2(a, b):
|
def test_range_iterator_2(a, b):
|
||||||
# should pass once we stop having three different paths on call_iter
|
# should pass once we stop having three different paths on call_iter
|
||||||
|
|
|
||||||
|
|
@ -1820,6 +1820,8 @@ class BuiltinVariable(VariableTracker):
|
||||||
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
|
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
|
||||||
if isinstance(obj, variables.IteratorVariable):
|
if isinstance(obj, variables.IteratorVariable):
|
||||||
ret = obj
|
ret = obj
|
||||||
|
elif isinstance(obj, variables.RangeVariable):
|
||||||
|
ret = obj.call_method(tx, "__iter__", [], {})
|
||||||
else:
|
else:
|
||||||
# Handle the case where we are iterating over a tuple, list or iterator
|
# Handle the case where we are iterating over a tuple, list or iterator
|
||||||
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user