mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] fix opcode YIELD_FROM and SEND (#123912)
This PR is split from #120300. - #120300 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123912 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
4b889d1247
commit
7b11fb4695
|
|
@ -8729,7 +8729,7 @@ def ___make_guard_fn():
|
|||
|
||||
return [t * k for t in yield_from_gen(t_list)]
|
||||
|
||||
t_list = [torch.randn([2, 3])] * 3
|
||||
t_list = [torch.randn([2, 3]) for _ in range(3)]
|
||||
eager = yield_from_fn(t_list, 2)
|
||||
counter = CompileCounter()
|
||||
compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2)
|
||||
|
|
@ -8778,6 +8778,34 @@ def ___make_guard_fn():
|
|||
self.assertEqual(eager, compiled)
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
def test_yield_from_user_stop_iteration(self):
|
||||
class MyIter:
|
||||
def __init__(self, seq):
|
||||
self.seq = seq
|
||||
self.index = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self.index += 1
|
||||
if self.index <= len(self.seq):
|
||||
return self.seq[self.index - 1]
|
||||
raise StopIteration(self.index)
|
||||
|
||||
def yield_from_iter_fn(seq):
|
||||
def gen(seq):
|
||||
yield from MyIter(seq)
|
||||
|
||||
return [i for i in gen(seq)]
|
||||
|
||||
seq = [torch.randn([2, 3]) for _ in range(3)]
|
||||
eager = yield_from_iter_fn(seq)
|
||||
counter = CompileCounter()
|
||||
compiled = torch._dynamo.optimize(counter)(yield_from_iter_fn)(seq)
|
||||
self.assertEqual(eager, compiled)
|
||||
self.assertEqual(counter.frame_count, 0)
|
||||
|
||||
def test_yield_send_to_subgenerator_graph_break(self):
|
||||
def subgenerator(tensor):
|
||||
multiplier = yield
|
||||
|
|
|
|||
|
|
@ -576,7 +576,6 @@ class TestFX(JitTestCase):
|
|||
with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
|
||||
tracer.trace(f)
|
||||
|
||||
|
||||
def test_graph_unique_names(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
|
|
@ -814,7 +813,6 @@ class TestFX(JitTestCase):
|
|||
# Return final GraphModule!!!
|
||||
return GraphModule(wrapper, graph)
|
||||
|
||||
|
||||
# Lower GraphModule to C++ interpreter
|
||||
lowered = lower_to_elementwise_interpreter(msm)
|
||||
|
||||
|
|
@ -870,7 +868,6 @@ class TestFX(JitTestCase):
|
|||
x = self.lin(x)
|
||||
return x
|
||||
|
||||
|
||||
ec = ExampleCode()
|
||||
|
||||
traced = torch.fx.symbolic_trace(ec)
|
||||
|
|
@ -878,7 +875,6 @@ class TestFX(JitTestCase):
|
|||
x = torch.randn(bs, d_hid)
|
||||
torch.testing.assert_close(ec(x), traced(x))
|
||||
|
||||
|
||||
def test_node_tagging(self):
|
||||
class TaggingTracer(Tracer):
|
||||
def create_node(self, kind : str, target : Union[str, Callable],
|
||||
|
|
@ -952,7 +948,6 @@ class TestFX(JitTestCase):
|
|||
traced.graph.lint()
|
||||
self.assertEqual(count_attrs(traced), 2)
|
||||
|
||||
|
||||
def test_symbolic_trace_sequential(self):
|
||||
class Simple(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
@ -1486,7 +1481,6 @@ class TestFX(JitTestCase):
|
|||
|
||||
self.assertTrue(neg in relu.users)
|
||||
|
||||
|
||||
def test_nonetype_annotation(self):
|
||||
eb = torch.nn.EmbeddingBag(3, 4)
|
||||
symbolic_trace(eb)
|
||||
|
|
@ -1506,7 +1500,6 @@ class TestFX(JitTestCase):
|
|||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return (x, x + x)
|
||||
|
||||
|
||||
original = M()
|
||||
traced = symbolic_trace(original)
|
||||
self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
|
||||
|
|
@ -1801,7 +1794,6 @@ class TestFX(JitTestCase):
|
|||
self.assertEqual(node.meta["stack_trace"], "stack_trace")
|
||||
self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
|
||||
|
||||
|
||||
def test_interpreter(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -2174,7 +2166,6 @@ class TestFX(JitTestCase):
|
|||
for node in to_erase:
|
||||
rn18_traced.graph.erase_node(node)
|
||||
|
||||
|
||||
def test_replace_input(self):
|
||||
graph : torch.fx.Graph = torch.fx.Graph()
|
||||
x : torch.fx.Node = graph.create_node('placeholder', 'x')
|
||||
|
|
@ -2217,7 +2208,6 @@ class TestFX(JitTestCase):
|
|||
inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
|
||||
self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
|
||||
|
||||
|
||||
b.update_arg(0, y)
|
||||
new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
|
||||
self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
|
||||
|
|
@ -2233,7 +2223,6 @@ class TestFX(JitTestCase):
|
|||
inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
|
||||
self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
|
||||
|
||||
|
||||
b.update_kwarg('input', y)
|
||||
new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
|
||||
self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
|
||||
|
|
@ -2391,7 +2380,6 @@ class TestFX(JitTestCase):
|
|||
x, y = torch.randn(3, 4), torch.randn(3, 4)
|
||||
self.checkGraphModule(foo, (x, y))
|
||||
|
||||
|
||||
def test_trace_return_dataclass(self):
|
||||
"""
|
||||
Test case for Module that return dataclass
|
||||
|
|
@ -2449,7 +2437,6 @@ class TestFX(JitTestCase):
|
|||
|
||||
self.assertEqual(module(x), gm(x))
|
||||
|
||||
|
||||
def test_trace_return_namedtuple(self):
|
||||
"""
|
||||
Test case for Module that return namedtuple
|
||||
|
|
@ -2462,7 +2449,6 @@ class TestFX(JitTestCase):
|
|||
def forward(self, d : torch.Tensor):
|
||||
return MyOutput(foo=d, bar=d)
|
||||
|
||||
|
||||
module = ModuleReturnNamedTuple()
|
||||
|
||||
traced_graph = symbolic_trace(module).graph
|
||||
|
|
@ -2748,7 +2734,6 @@ class TestFX(JitTestCase):
|
|||
proc.join()
|
||||
self.assertEqual(proc.exitcode, 0)
|
||||
|
||||
|
||||
def test_user_friendly_call_provenance_with_function(self):
|
||||
def fn(x):
|
||||
return wrapper_fn(x)
|
||||
|
|
@ -3597,7 +3582,7 @@ class TestFX(JitTestCase):
|
|||
|
||||
def verify_pytree(f, inp):
|
||||
val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
|
||||
num_flat_args = len([i == PH for i in pytree.tree_leaves(inp)])
|
||||
num_flat_args = len(pytree.tree_leaves(inp))
|
||||
orig_out = f(val)
|
||||
nf = symbolic_trace(f, concrete_args={'x': inp})
|
||||
self.assertEqual(nf(val), orig_out)
|
||||
|
|
@ -3867,7 +3852,6 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
|||
m.graph.lint()
|
||||
|
||||
|
||||
|
||||
def run_getitem_target():
|
||||
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
|
||||
_wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
|
||||
|
|
|
|||
|
|
@ -159,8 +159,16 @@ class UserError(Unsupported):
|
|||
|
||||
|
||||
class UserStopIteration(TorchDynamoException):
|
||||
def __init__(self):
|
||||
value: Optional[Any]
|
||||
|
||||
# Reference `StopIteration_init` in CPython
|
||||
# https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__("unhandled `raise StopIteration`")
|
||||
if len(args) > 0:
|
||||
self.value = args[0]
|
||||
else:
|
||||
self.value = None
|
||||
|
||||
|
||||
class UncapturedHigherOrderOpError(TorchDynamoException):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import collections
|
||||
import collections.abc
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
|
|
@ -1179,7 +1180,7 @@ class InstructionTranslatorBase(
|
|||
# Python 3.8 only
|
||||
addr = self.indexof[self.next_instruction]
|
||||
self.push(ConstantVariable.create(addr))
|
||||
self.instruction_pointer = self.indexof[inst.target]
|
||||
self.jump(inst)
|
||||
|
||||
def END_FINALLY(self, inst):
|
||||
# Python 3.8 only
|
||||
|
|
@ -2636,7 +2637,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
|||
|
||||
def YIELD_VALUE(self, inst: Instruction):
|
||||
self.generated_items.append(self.pop())
|
||||
# TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
|
||||
self.push(ConstantVariable.create(None))
|
||||
|
||||
def GET_YIELD_FROM_ITER(self, inst):
|
||||
|
|
@ -2645,61 +2645,61 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
|||
self.pop()
|
||||
res = BuiltinVariable(iter).call_function(self, [tos], {})
|
||||
self.push(res)
|
||||
return self.YIELD_FROM(inst)
|
||||
|
||||
def YIELD_FROM(self, inst):
|
||||
while True:
|
||||
tos = self.stack[-1].realize()
|
||||
if isinstance(tos, ConstantVariable) and tos.value is None:
|
||||
self.pop()
|
||||
return
|
||||
try:
|
||||
val = tos.next_variable(self)
|
||||
assert len(self.stack) >= 2
|
||||
val = self.pop()
|
||||
tos = self.stack[-1]
|
||||
if not (isinstance(val, ConstantVariable) and val.value is None):
|
||||
# invoke send
|
||||
# Unreachable code - if you hit this, you are implementing generator support and have
|
||||
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
|
||||
# subgenerator and lines up with this line in Python 3.10
|
||||
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
|
||||
unimplemented("Unreachable sub-generator code")
|
||||
|
||||
# TODO(anijain2305,jansel) - The last pop is because
|
||||
# YIELD_FROM. If we remove it from there, we don't need to
|
||||
# pop it here.
|
||||
self.push(val)
|
||||
self.YIELD_VALUE(inst)
|
||||
self.pop()
|
||||
try:
|
||||
val = tos.next_variable(self)
|
||||
except (StopIteration, exc.UserStopIteration) as ex:
|
||||
# The iterator is exhausted. Stop the loop and return.
|
||||
self.pop()
|
||||
self.push(ConstantVariable.create(ex.value))
|
||||
else:
|
||||
self.push(val)
|
||||
# Add the value to yield into generated_items and replace the top of the stack with None
|
||||
self.YIELD_VALUE(inst)
|
||||
|
||||
# Pop the old iter and push the new iter
|
||||
self.pop()
|
||||
self.push(tos)
|
||||
except (StopIteration, exc.UserStopIteration):
|
||||
return
|
||||
# Repeat the YIELD_FROM instruction in the next eval loop
|
||||
assert (
|
||||
isinstance(self.instruction_pointer, int)
|
||||
and self.instruction_pointer > 0
|
||||
)
|
||||
self.instruction_pointer -= 1
|
||||
|
||||
def SEND(self, inst):
|
||||
assert len(self.stack) >= 2
|
||||
val = self.pop()
|
||||
tos = self.stack[-1]
|
||||
if isinstance(tos, ListIteratorVariable):
|
||||
# We handle yield in a very differnt way than CPython does. Instead
|
||||
# of returning to the parent frame on a yield, TorchDynamo instead
|
||||
# just collects the generated_items and proceed to the next
|
||||
# instruction in the same frame. From bytecode tracing stanpoint,
|
||||
# this means that the iterator returned from the child funtion on
|
||||
# `yield from ...` will always be exhausted.
|
||||
|
||||
# Therefore to implement SEND, we have to look at the implementation
|
||||
# when the iterator returns StopIteration. This translates to this code
|
||||
# 3.11 - https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2618
|
||||
# 3.12 - https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L865
|
||||
# The implementation is different in 3.11 and 3.12. In 3.12, we rely
|
||||
# on END_SEND to clean up. In 3.11, SEND does the cleanup as well.
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
# Do not pop, we will rely on END_SEND to pop the iterator
|
||||
pass
|
||||
else:
|
||||
# Check that the iterator is exhausted. It should be because of
|
||||
# how we implement yields.
|
||||
assert tos.is_exhausted()
|
||||
self.pop()
|
||||
|
||||
if isinstance(tos, ListIteratorVariable) or (
|
||||
isinstance(tos, UserDefinedObjectVariable)
|
||||
and isinstance(tos.value, collections.abc.Iterator)
|
||||
):
|
||||
if isinstance(val, ConstantVariable) and val.value is None:
|
||||
self.push(val)
|
||||
self.instruction_pointer = self.indexof[inst.target]
|
||||
try:
|
||||
val = tos.next_variable(self)
|
||||
except (StopIteration, exc.UserStopIteration) as ex:
|
||||
# To implement SEND, we have to look at the implementation
|
||||
# when the iterator returns StopIteration. This translates to this code
|
||||
# 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619
|
||||
# 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866
|
||||
# The implementation is different in 3.11 and 3.12. In 3.12, we rely
|
||||
# on END_SEND to clean up. In 3.11, SEND does the cleanup as well.
|
||||
if sys.version_info < (3, 12):
|
||||
self.pop() # Python 3.12 uses new opcode END_SEND
|
||||
self.push(ConstantVariable.create(ex.value))
|
||||
self.jump(inst)
|
||||
else:
|
||||
self.push(val)
|
||||
else:
|
||||
# invoke send
|
||||
# Unreachable code - if you hit this, you are implementing generator support and have
|
||||
|
|
|
|||
|
|
@ -678,9 +678,6 @@ class ListIteratorVariable(VariableTracker):
|
|||
]
|
||||
)
|
||||
|
||||
def is_exhausted(self):
|
||||
return self.index >= len(self.items)
|
||||
|
||||
|
||||
class TupleIteratorVariable(ListIteratorVariable):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user