[dynamo][annotate] Graph break cleanly on fx.traceback.annotate reconstruction (#166006)

This avoids generation of bad bytecode, leading to really confusing
error. I am not sure why we can't reconstruct cleanly, it has to do with
the input being a dict, while other supported ctx managers take bools.

Fixing that is for another day. Lets give a good error message for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166006
Approved by: https://github.com/yushangdi, https://github.com/SherlockNoMad
This commit is contained in:
Animesh Jain 2025-10-21 10:46:17 -07:00 committed by PyTorch MergeBot
parent ad4dc52bf6
commit 830e789a55
3 changed files with 32 additions and 0 deletions

View File

@ -288,6 +288,18 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)
def test_graph_break(self):
def fn(x):
with torch.fx.traceback.annotate({"pp_stage": 0}):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
self.assertEqual(fn(x), opt_fn(x))
if __name__ == "__main__":
run_tests()

View File

@ -2810,5 +2810,15 @@
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
],
"GB0279": [
{
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
"Context": "str(self)",
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@ -1295,6 +1295,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
def fn_name(self):
return "annotate"
def reconstruct_type(self, codegen: "PyCodegen"):
unimplemented_v2(
gb_type="torch.fx.traceback.annotate escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""