[Dynamo] fix opcode YIELD_FROM and SEND (#123912)

This PR is split from #120300.

- #120300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123912
Approved by: https://github.com/anijain2305
This commit is contained in:
Xuehai Pan 2024-04-12 18:24:20 +00:00 committed by PyTorch MergeBot
parent 4b889d1247
commit 7b11fb4695
6 changed files with 86 additions and 69 deletions

View File

@ -8729,7 +8729,7 @@ def ___make_guard_fn():
return [t * k for t in yield_from_gen(t_list)]
t_list = [torch.randn([2, 3])] * 3
t_list = [torch.randn([2, 3]) for _ in range(3)]
eager = yield_from_fn(t_list, 2)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2)
@ -8778,6 +8778,34 @@ def ___make_guard_fn():
self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 1)
def test_yield_from_user_stop_iteration(self):
class MyIter:
def __init__(self, seq):
self.seq = seq
self.index = 0
def __iter__(self):
return self
def __next__(self):
self.index += 1
if self.index <= len(self.seq):
return self.seq[self.index - 1]
raise StopIteration(self.index)
def yield_from_iter_fn(seq):
def gen(seq):
yield from MyIter(seq)
return [i for i in gen(seq)]
seq = [torch.randn([2, 3]) for _ in range(3)]
eager = yield_from_iter_fn(seq)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(yield_from_iter_fn)(seq)
self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 0)
def test_yield_send_to_subgenerator_graph_break(self):
def subgenerator(tensor):
multiplier = yield

View File

@ -576,7 +576,6 @@ class TestFX(JitTestCase):
with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
tracer.trace(f)
def test_graph_unique_names(self):
class M(torch.nn.Module):
def forward(self, a, b):
@ -814,7 +813,6 @@ class TestFX(JitTestCase):
# Return final GraphModule!!!
return GraphModule(wrapper, graph)
# Lower GraphModule to C++ interpreter
lowered = lower_to_elementwise_interpreter(msm)
@ -870,7 +868,6 @@ class TestFX(JitTestCase):
x = self.lin(x)
return x
ec = ExampleCode()
traced = torch.fx.symbolic_trace(ec)
@ -878,7 +875,6 @@ class TestFX(JitTestCase):
x = torch.randn(bs, d_hid)
torch.testing.assert_close(ec(x), traced(x))
def test_node_tagging(self):
class TaggingTracer(Tracer):
def create_node(self, kind : str, target : Union[str, Callable],
@ -952,7 +948,6 @@ class TestFX(JitTestCase):
traced.graph.lint()
self.assertEqual(count_attrs(traced), 2)
def test_symbolic_trace_sequential(self):
class Simple(torch.nn.Module):
def forward(self, x):
@ -1486,7 +1481,6 @@ class TestFX(JitTestCase):
self.assertTrue(neg in relu.users)
def test_nonetype_annotation(self):
eb = torch.nn.EmbeddingBag(3, 4)
symbolic_trace(eb)
@ -1506,7 +1500,6 @@ class TestFX(JitTestCase):
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return (x, x + x)
original = M()
traced = symbolic_trace(original)
self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
@ -1801,7 +1794,6 @@ class TestFX(JitTestCase):
self.assertEqual(node.meta["stack_trace"], "stack_trace")
self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
def test_interpreter(self):
class MyModule(torch.nn.Module):
def __init__(self):
@ -2174,7 +2166,6 @@ class TestFX(JitTestCase):
for node in to_erase:
rn18_traced.graph.erase_node(node)
def test_replace_input(self):
graph : torch.fx.Graph = torch.fx.Graph()
x : torch.fx.Node = graph.create_node('placeholder', 'x')
@ -2217,7 +2208,6 @@ class TestFX(JitTestCase):
inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
b.update_arg(0, y)
new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
@ -2233,7 +2223,6 @@ class TestFX(JitTestCase):
inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
b.update_kwarg('input', y)
new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
@ -2391,7 +2380,6 @@ class TestFX(JitTestCase):
x, y = torch.randn(3, 4), torch.randn(3, 4)
self.checkGraphModule(foo, (x, y))
def test_trace_return_dataclass(self):
"""
Test case for Module that return dataclass
@ -2449,7 +2437,6 @@ class TestFX(JitTestCase):
self.assertEqual(module(x), gm(x))
def test_trace_return_namedtuple(self):
"""
Test case for Module that return namedtuple
@ -2462,7 +2449,6 @@ class TestFX(JitTestCase):
def forward(self, d : torch.Tensor):
return MyOutput(foo=d, bar=d)
module = ModuleReturnNamedTuple()
traced_graph = symbolic_trace(module).graph
@ -2748,7 +2734,6 @@ class TestFX(JitTestCase):
proc.join()
self.assertEqual(proc.exitcode, 0)
def test_user_friendly_call_provenance_with_function(self):
def fn(x):
return wrapper_fn(x)
@ -3597,7 +3582,7 @@ class TestFX(JitTestCase):
def verify_pytree(f, inp):
val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
num_flat_args = len([i == PH for i in pytree.tree_leaves(inp)])
num_flat_args = len(pytree.tree_leaves(inp))
orig_out = f(val)
nf = symbolic_trace(f, concrete_args={'x': inp})
self.assertEqual(nf(val), orig_out)
@ -3867,7 +3852,6 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
m.graph.lint()
def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
_wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))

View File

@ -159,8 +159,16 @@ class UserError(Unsupported):
class UserStopIteration(TorchDynamoException):
def __init__(self):
value: Optional[Any]
# Reference `StopIteration_init` in CPython
# https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
def __init__(self, *args, **kwargs):
super().__init__("unhandled `raise StopIteration`")
if len(args) > 0:
self.value = args[0]
else:
self.value = None
class UncapturedHigherOrderOpError(TorchDynamoException):

View File

@ -1,4 +1,5 @@
import collections
import collections.abc
import contextlib
import copy
import dataclasses
@ -1179,7 +1180,7 @@ class InstructionTranslatorBase(
# Python 3.8 only
addr = self.indexof[self.next_instruction]
self.push(ConstantVariable.create(addr))
self.instruction_pointer = self.indexof[inst.target]
self.jump(inst)
def END_FINALLY(self, inst):
# Python 3.8 only
@ -2636,7 +2637,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
def YIELD_VALUE(self, inst: Instruction):
self.generated_items.append(self.pop())
# TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
self.push(ConstantVariable.create(None))
def GET_YIELD_FROM_ITER(self, inst):
@ -2645,61 +2645,61 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
self.pop()
res = BuiltinVariable(iter).call_function(self, [tos], {})
self.push(res)
return self.YIELD_FROM(inst)
def YIELD_FROM(self, inst):
while True:
tos = self.stack[-1].realize()
if isinstance(tos, ConstantVariable) and tos.value is None:
self.pop()
return
try:
val = tos.next_variable(self)
assert len(self.stack) >= 2
val = self.pop()
tos = self.stack[-1]
if not (isinstance(val, ConstantVariable) and val.value is None):
# invoke send
# Unreachable code - if you hit this, you are implementing generator support and have
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
# subgenerator and lines up with this line in Python 3.10
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
unimplemented("Unreachable sub-generator code")
# TODO(anijain2305,jansel) - The last pop is because
# YIELD_FROM. If we remove it from there, we don't need to
# pop it here.
self.push(val)
self.YIELD_VALUE(inst)
self.pop()
try:
val = tos.next_variable(self)
except (StopIteration, exc.UserStopIteration) as ex:
# The iterator is exhausted. Stop the loop and return.
self.pop()
self.push(ConstantVariable.create(ex.value))
else:
self.push(val)
# Add the value to yield into generated_items and replace the top of the stack with None
self.YIELD_VALUE(inst)
# Pop the old iter and push the new iter
self.pop()
self.push(tos)
except (StopIteration, exc.UserStopIteration):
return
# Repeat the YIELD_FROM instruction in the next eval loop
assert (
isinstance(self.instruction_pointer, int)
and self.instruction_pointer > 0
)
self.instruction_pointer -= 1
def SEND(self, inst):
assert len(self.stack) >= 2
val = self.pop()
tos = self.stack[-1]
if isinstance(tos, ListIteratorVariable):
# We handle yield in a very differnt way than CPython does. Instead
# of returning to the parent frame on a yield, TorchDynamo instead
# just collects the generated_items and proceed to the next
# instruction in the same frame. From bytecode tracing stanpoint,
# this means that the iterator returned from the child funtion on
# `yield from ...` will always be exhausted.
# Therefore to implement SEND, we have to look at the implementation
# when the iterator returns StopIteration. This translates to this code
# 3.11 - https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2618
# 3.12 - https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L865
# The implementation is different in 3.11 and 3.12. In 3.12, we rely
# on END_SEND to clean up. In 3.11, SEND does the cleanup as well.
if sys.version_info >= (3, 12):
# Do not pop, we will rely on END_SEND to pop the iterator
pass
else:
# Check that the iterator is exhausted. It should be because of
# how we implement yields.
assert tos.is_exhausted()
self.pop()
if isinstance(tos, ListIteratorVariable) or (
isinstance(tos, UserDefinedObjectVariable)
and isinstance(tos.value, collections.abc.Iterator)
):
if isinstance(val, ConstantVariable) and val.value is None:
self.push(val)
self.instruction_pointer = self.indexof[inst.target]
try:
val = tos.next_variable(self)
except (StopIteration, exc.UserStopIteration) as ex:
# To implement SEND, we have to look at the implementation
# when the iterator returns StopIteration. This translates to this code
# 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619
# 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866
# The implementation is different in 3.11 and 3.12. In 3.12, we rely
# on END_SEND to clean up. In 3.11, SEND does the cleanup as well.
if sys.version_info < (3, 12):
self.pop() # Python 3.12 uses new opcode END_SEND
self.push(ConstantVariable.create(ex.value))
self.jump(inst)
else:
self.push(val)
else:
# invoke send
# Unreachable code - if you hit this, you are implementing generator support and have

View File

@ -678,9 +678,6 @@ class ListIteratorVariable(VariableTracker):
]
)
def is_exhausted(self):
return self.index >= len(self.items)
class TupleIteratorVariable(ListIteratorVariable):
pass