mirror of https://github.com/zebrajr/pytorch.git
[dynamo][annotate] Graph break cleanly on fx.traceback.annotate reconstruction (#166006)
This avoids generating bad bytecode, which led to a really confusing error. I am not sure why we can't reconstruct cleanly; it has to do with the input being a dict, while the other supported ctx managers take bools. Fixing that is for another day. Let's give a good error message for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166006
Approved by: https://github.com/yushangdi, https://github.com/SherlockNoMad
This commit is contained in:
parent ad4dc52bf6
commit 830e789a55
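For context, a minimal sketch of the pattern this change targets, mirroring the test added in the diff below (the eager backend and tensor shape are just the test's illustrative choices): a graph break inside a torch.fx.traceback.annotate region, whose payload is a dict rather than the bools other supported ctx managers take. Per the commit message, Dynamo is meant to graph break cleanly here (registered as GB0279) instead of emitting bad bytecode, and the new test checks that compiled and eager results still match.

import torch

def fn(x):
    # dict-valued annotation; the other supported ctx managers take bools
    with torch.fx.traceback.annotate({"pp_stage": 0}):
        x = torch.sin(x)
        torch._dynamo.graph_break()  # force a graph break inside the annotate region
        x = torch.cos(x)
    return x

opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
print(torch.allclose(fn(x), opt_fn(x)))  # expected True, matching the new test's assertEqual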
@@ -288,6 +288,18 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
 ('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""",  # noqa: B950
         )

+    def test_graph_break(self):
+        def fn(x):
+            with torch.fx.traceback.annotate({"pp_stage": 0}):
+                x = torch.sin(x)
+                torch._dynamo.graph_break()
+                x = torch.cos(x)
+            return x
+
+        opt_fn = torch.compile(fn, backend="eager")
+        x = torch.randn(10, requires_grad=True)
+        self.assertEqual(fn(x), opt_fn(x))
+

 if __name__ == "__main__":
     run_tests()
@@ -2810,5 +2810,15 @@
         "Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
       ]
     }
+  ],
+  "GB0279": [
+    {
+      "Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
+      "Context": "str(self)",
+      "Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
   ]
 }
@@ -1295,6 +1295,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
     def fn_name(self):
         return "annotate"

+    def reconstruct_type(self, codegen: "PyCodegen"):
+        unimplemented_v2(
+            gb_type="torch.fx.traceback.annotate escaped from compiled region",
+            context=str(self),
+            explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
+            hints=[
+                *graph_break_hints.SUPPORTABLE,
+            ],
+        )
+

 class DynamoConfigPatchVariable(ContextWrappingVariable):
     """represents torch._dynamo.patch_dynamo_config"""