mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][annotate] Graph break cleanly on fx.traceback.annotate reconstruction (#166006)
This avoids generation of bad bytecode, leading to really confusing error. I am not sure why we can't reconstruct cleanly, it has to do with the input being a dict, while other supported ctx managers take bools. Fixing that is for another day. Lets give a good error message for now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166006 Approved by: https://github.com/yushangdi, https://github.com/SherlockNoMad
This commit is contained in:
parent
ad4dc52bf6
commit
830e789a55
|
|
@ -288,6 +288,18 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
|
|||
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_graph_break(self):
|
||||
def fn(x):
|
||||
with torch.fx.traceback.annotate({"pp_stage": 0}):
|
||||
x = torch.sin(x)
|
||||
torch._dynamo.graph_break()
|
||||
x = torch.cos(x)
|
||||
return x
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -2810,5 +2810,15 @@
|
|||
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0279": [
|
||||
{
|
||||
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
|
||||
"Context": "str(self)",
|
||||
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
|
||||
"Hints": [
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1295,6 +1295,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
|
|||
def fn_name(self):
|
||||
return "annotate"
|
||||
|
||||
def reconstruct_type(self, codegen: "PyCodegen"):
|
||||
unimplemented_v2(
|
||||
gb_type="torch.fx.traceback.annotate escaped from compiled region",
|
||||
context=str(self),
|
||||
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
|
||||
hints=[
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class DynamoConfigPatchVariable(ContextWrappingVariable):
|
||||
"""represents torch._dynamo.patch_dynamo_config"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user