mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Do not leak SkipFrame exception to parent frames (#91059)
Discovered by https://github.com/pytorch/torchdynamo/issues/2000, we noticed the exception `SkipFrame` to avoid repeatedly compiling frame of loop with graph breaks could leak to parent frames while inlining, which then prevents compiling. This PR checks at inlining if such exception is raised and would instead raise an `Unsupported` to the outer frame. The original behavior and goal of #88857 is unaffected: the inner frame that has loop would still be skipped. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91059 Approved by: https://github.com/jansel, https://github.com/thiagocrepaldi
This commit is contained in:
parent
a60125e298
commit
a72bcb3388
|
|
@ -2081,6 +2081,56 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, 1)
|
||||
|
||||
def test_nested_while_loop_graph_break(self):
|
||||
def inner_loop(x):
|
||||
i = 3
|
||||
while i > 0:
|
||||
i -= 1
|
||||
x += 1
|
||||
torch._dynamo.graph_break()
|
||||
return x
|
||||
|
||||
def inner(x):
|
||||
inner_loop(x)
|
||||
return torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
i = 20
|
||||
while i > 10:
|
||||
x = inner(x)
|
||||
i -= 1
|
||||
torch._dynamo.graph_break()
|
||||
return x
|
||||
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
||||
x = torch.randn(4)
|
||||
opt_fn(x)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, 1)
|
||||
|
||||
def test_while_loop_graph_break_inside_call_function(self):
|
||||
# Repro of huggingface graph break inside loop in `get_parameter_dtype`.
|
||||
# Skip only the inner frame that has loop that contains graph break.
|
||||
def inner(x):
|
||||
for i in range(3):
|
||||
x += 1
|
||||
torch._dynamo.graph_break()
|
||||
return x
|
||||
|
||||
def fn(x):
|
||||
x += 2
|
||||
inner(x)
|
||||
x += 3
|
||||
return x
|
||||
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
||||
x = torch.randn(4)
|
||||
opt_fn(x)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
self.assertEqual(cnt.op_count, 2)
|
||||
|
||||
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
|
||||
def test_rewrite_assert_with_msg(self):
|
||||
def f(x):
|
||||
|
|
|
|||
|
|
@ -513,6 +513,7 @@ class TestTorchDeviceType(TestCase):
|
|||
|
||||
# collected tests of ops that used scalar_check in Declarations.cwrap for
|
||||
# correctness
|
||||
@skipIfTorchInductor("segfaults")
|
||||
def test_scalar_check(self, device):
|
||||
zero_d = torch.randn((), device=device)
|
||||
one_d = torch.randn((1,), device=device)
|
||||
|
|
@ -3801,6 +3802,7 @@ else:
|
|||
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
|
||||
|
||||
# functions that operate over a dimension but don't reduce.
|
||||
@skipIfTorchInductor("RuntimeError: Trying to create tensor with negative dimension -1: [-1]")
|
||||
def test_dim_function_empty(self, device):
|
||||
shape = (0, 1, 2, 0)
|
||||
x = torch.randn(shape, device=device)
|
||||
|
|
@ -5686,6 +5688,7 @@ class TestTorch(TestCase):
|
|||
added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1)
|
||||
self.assertEqual(added, -tensor)
|
||||
|
||||
@skipIfTorchInductor("AssertionError: RuntimeError not raised by <lambda>")
|
||||
def test_index_add_correctness(self):
|
||||
# Check whether index_add can get correct result when
|
||||
# alpha is 1, and dtype of index is torch.long,
|
||||
|
|
|
|||
|
|
@ -1796,9 +1796,9 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
|
||||
try:
|
||||
sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
|
||||
except TypeError as exc:
|
||||
except TypeError as e:
|
||||
log.warning(
|
||||
f"{func.get_filename()} {func.get_function()} {args} {kwargs} {exc}"
|
||||
f"{func.get_filename()} {func.get_function()} {args} {kwargs} {e}"
|
||||
)
|
||||
unimplemented("arg mismatch inlining")
|
||||
|
||||
|
|
@ -1822,7 +1822,15 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
|
||||
)
|
||||
|
||||
tracer.run()
|
||||
try:
|
||||
tracer.run()
|
||||
except exc.SkipFrame as e:
|
||||
msg = f"SKIPPED INLINING {code}: {e}"
|
||||
log.debug(msg)
|
||||
raise Unsupported(msg) from e
|
||||
except Exception as e:
|
||||
log.debug(f"FAILED INLINING {code}")
|
||||
raise
|
||||
assert tracer.symbolic_result is not None
|
||||
func.export_freevars(parent, tracer)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user