mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Implement generator.send(..) (#144422)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144422 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421
This commit is contained in:
parent
d798831167
commit
ca9b16e070
|
|
@ -613,13 +613,51 @@ class GraphModule(torch.nn.Module):
|
||||||
self.assertEqual(y, t + sum(range(6)))
|
self.assertEqual(y, t + sum(range(6)))
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneratorSend(GeneratorTestsBase):
|
||||||
|
def test_send(self):
|
||||||
|
def double():
|
||||||
|
x = yield
|
||||||
|
yield x * 2
|
||||||
|
|
||||||
|
@torch.compile(backend="eager", fullgraph=True)
|
||||||
|
def fn(t):
|
||||||
|
gen = double()
|
||||||
|
next(gen)
|
||||||
|
return gen.send(t)
|
||||||
|
|
||||||
|
t = torch.randn(2)
|
||||||
|
y = fn(t)
|
||||||
|
self.assertEqual(y, t * 2)
|
||||||
|
|
||||||
|
@parametrize("fullgraph", [True, False])
|
||||||
|
def test_send_stop_iteration(self, fullgraph):
|
||||||
|
def double():
|
||||||
|
x = yield
|
||||||
|
yield x * 2
|
||||||
|
|
||||||
|
@torch.compile(backend="eager", fullgraph=fullgraph)
|
||||||
|
def fn(t):
|
||||||
|
gen = double()
|
||||||
|
next(gen)
|
||||||
|
a = gen.send(t)
|
||||||
|
b = gen.send(t) # should result in StopIteration
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
t = torch.randn(2)
|
||||||
|
if fullgraph:
|
||||||
|
with self.assertRaisesRegex(Unsupported, "Observed exception"):
|
||||||
|
fn(t)
|
||||||
|
else:
|
||||||
|
with self.assertRaises(StopIteration):
|
||||||
|
fn(t)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorCPythonTests(GeneratorTestsBase):
|
class GeneratorCPythonTests(GeneratorTestsBase):
|
||||||
# Taken from commit
|
# Taken from commit
|
||||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||||
# changed the tests a little bit to run them inside dynamo
|
# changed the tests a little bit to run them inside dynamo
|
||||||
# + replaced all self.assert* calls to plain assert statements
|
# + replaced all self.assert* calls to plain assert statements
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_send_non_none_to_new_gen(self):
|
def test_send_non_none_to_new_gen(self):
|
||||||
def f():
|
def f():
|
||||||
yield 1
|
yield 1
|
||||||
|
|
@ -661,6 +699,7 @@ class GeneratorCPythonTests(GeneratorTestsBase):
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(GeneratorTests)
|
instantiate_parametrized_tests(GeneratorTests)
|
||||||
|
instantiate_parametrized_tests(TestGeneratorSend)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -290,6 +290,11 @@ class ObservedNotImplementedError(ObservedException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ObservedTypeError(ObservedException):
|
||||||
|
# A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) method
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
observed_exception_map = {
|
observed_exception_map = {
|
||||||
StopIteration: ObservedUserStopIteration,
|
StopIteration: ObservedUserStopIteration,
|
||||||
LookupError: ObservedLookupError,
|
LookupError: ObservedLookupError,
|
||||||
|
|
@ -299,6 +304,7 @@ observed_exception_map = {
|
||||||
AttributeError: ObservedAttributeError,
|
AttributeError: ObservedAttributeError,
|
||||||
RuntimeError: ObservedRuntimeError,
|
RuntimeError: ObservedRuntimeError,
|
||||||
NotImplementedError: ObservedNotImplementedError,
|
NotImplementedError: ObservedNotImplementedError,
|
||||||
|
TypeError: ObservedTypeError,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -490,6 +490,9 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||||
break
|
break
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _is_generator_just_started(self):
|
||||||
|
return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0
|
||||||
|
|
||||||
def call_method(
|
def call_method(
|
||||||
self,
|
self,
|
||||||
tx: "InstructionTranslator",
|
tx: "InstructionTranslator",
|
||||||
|
|
@ -502,6 +505,21 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||||
elif name == "__iter__":
|
elif name == "__iter__":
|
||||||
# iter(gen) returns itself
|
# iter(gen) returns itself
|
||||||
return self
|
return self
|
||||||
|
elif name == "send":
|
||||||
|
# Sends a value into the generator function. Returns the next value
|
||||||
|
# yielded by the generator, or raises StopIteration if the generator
|
||||||
|
# exits without yielding another value
|
||||||
|
if self._is_generator_just_started() and len(args):
|
||||||
|
# can't send non-None value to a just-started generator
|
||||||
|
# Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
|
||||||
|
if not all(
|
||||||
|
isinstance(arg, ConstantVariable) and arg.value is None
|
||||||
|
for arg in args
|
||||||
|
):
|
||||||
|
raise_observed_exception(TypeError, tx)
|
||||||
|
tracer = self._get_inline_tracer(tx)
|
||||||
|
tracer.push_many(args)
|
||||||
|
return self.next_variable(tx)
|
||||||
|
|
||||||
super().call_method(tx, name, args, kwargs)
|
super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user