Do not leak SkipFrame exception to parent frames (#91059)

Discovered via https://github.com/pytorch/torchdynamo/issues/2000: the `SkipFrame` exception — raised to avoid repeatedly compiling a frame whose loop contains graph breaks — could leak to parent frames during inlining, which then prevented those frames from being compiled.

This PR checks, during inlining, whether such an exception is raised, and instead raises `Unsupported` to the outer frame. The original behavior and goal of #88857 are unaffected: the inner frame containing the loop is still skipped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91059
Approved by: https://github.com/jansel, https://github.com/thiagocrepaldi
This commit is contained in:
BowenBao 2023-01-12 10:06:59 -08:00 committed by PyTorch MergeBot
parent a60125e298
commit a72bcb3388
3 changed files with 64 additions and 3 deletions

View File

@ -2081,6 +2081,56 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 1)
def test_nested_while_loop_graph_break(self):
    # Regression test for pytorch/pytorch#91059: a SkipFrame raised while
    # inlining a loop frame with graph breaks must not leak into parent
    # frames and prevent them from compiling.
    def inner_loop(x):
        # Loop containing a graph break — dynamo skips this frame rather
        # than recompiling it on every iteration.
        i = 3
        while i > 0:
            i -= 1
            x += 1
            torch._dynamo.graph_break()
        return x

    def inner(x):
        inner_loop(x)
        return torch.sin(x)

    def fn(x):
        # Outer loop with its own graph break, nesting the skipped frame.
        i = 20
        while i > 10:
            x = inner(x)
            i -= 1
            torch._dynamo.graph_break()
        return x

    cnt = torch._dynamo.testing.CompileCounter()
    opt_fn = torch._dynamo.optimize(cnt)(fn)
    x = torch.randn(4)
    opt_fn(x)
    # Exactly one frame/op compiles despite the nested skipped loop frames
    # (presumably the torch.sin in `inner` — the loop frames are skipped).
    self.assertEqual(cnt.frame_count, 1)
    self.assertEqual(cnt.op_count, 1)
def test_while_loop_graph_break_inside_call_function(self):
    # Repro of huggingface graph break inside loop in `get_parameter_dtype`.
    # Skip only the inner frame that has loop that contains graph break.
    def inner(x):
        # Loop with a graph break: this frame gets skipped by dynamo.
        for i in range(3):
            x += 1
            torch._dynamo.graph_break()
        return x

    def fn(x):
        # Work on either side of the skipped call — each side should still
        # compile as its own frame.
        x += 2
        inner(x)
        x += 3
        return x

    cnt = torch._dynamo.testing.CompileCounter()
    opt_fn = torch._dynamo.optimize(cnt)(fn)
    x = torch.randn(4)
    opt_fn(x)
    # Two frames compile (code before and after the `inner` call); the
    # looping `inner` frame itself is skipped, not the whole call chain.
    self.assertEqual(cnt.frame_count, 2)
    self.assertEqual(cnt.op_count, 2)
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_msg(self):
def f(x):

View File

@ -513,6 +513,7 @@ class TestTorchDeviceType(TestCase):
# collected tests of ops that used scalar_check in Declarations.cwrap for
# correctness
@skipIfTorchInductor("segfaults")
def test_scalar_check(self, device):
zero_d = torch.randn((), device=device)
one_d = torch.randn((1,), device=device)
@ -3801,6 +3802,7 @@ else:
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
# functions that operate over a dimension but don't reduce.
@skipIfTorchInductor("RuntimeError: Trying to create tensor with negative dimension -1: [-1]")
def test_dim_function_empty(self, device):
shape = (0, 1, 2, 0)
x = torch.randn(shape, device=device)
@ -5686,6 +5688,7 @@ class TestTorch(TestCase):
added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1)
self.assertEqual(added, -tensor)
@skipIfTorchInductor("AssertionError: RuntimeError not raised by <lambda>")
def test_index_add_correctness(self):
# Check whether index_add can get correct result when
# alpha is 1, and dtype of index is torch.long,

View File

@ -1796,9 +1796,9 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
try:
sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
except TypeError as exc:
except TypeError as e:
log.warning(
f"{func.get_filename()} {func.get_function()} {args} {kwargs} {exc}"
f"{func.get_filename()} {func.get_function()} {args} {kwargs} {e}"
)
unimplemented("arg mismatch inlining")
@ -1822,7 +1822,15 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
)
tracer.run()
try:
tracer.run()
except exc.SkipFrame as e:
msg = f"SKIPPED INLINING {code}: {e}"
log.debug(msg)
raise Unsupported(msg) from e
except Exception as e:
log.debug(f"FAILED INLINING {code}")
raise
assert tracer.symbolic_result is not None
func.export_freevars(parent, tracer)