mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Fix an inductor_provenance bug (#166432)
Summary: Fix an inductor_provenance related error seen when running TORCH_COMPILE_DEBUG generated fx_graph_runnable.py. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166432 Approved by: https://github.com/mlazos
This commit is contained in:
parent
3f1824742c
commit
08b0a8f11a
|
|
@ -363,6 +363,25 @@ class FxGraphRunnableTest(TestCase):
|
|||
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
def test_metrics_context(self):
|
||||
"""
|
||||
When TORCH_COMPILE_DEBUG is set, provenance_tracking_level is set to 1, and
|
||||
the generated fx_graph_runnable crashed with,
|
||||
RuntimeError: Cannot add inductor_provenance outside of a MetricsContext
|
||||
"""
|
||||
import torch._inductor.config as inductor_config
|
||||
|
||||
def f(x):
|
||||
return x * 2 + 1
|
||||
|
||||
# Enable provenance tracking to trigger the code path that adds metrics
|
||||
with inductor_config.patch(
|
||||
{"trace.enabled": True, "trace.provenance_tracking_level": 1}
|
||||
):
|
||||
x = torch.randn(4, 4)
|
||||
torch.compile(f)(x)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -1550,8 +1550,11 @@ class _InProcessFxCompile(FxCompile):
|
|||
payload_fn=lambda: inductor_kernel_stack_trace_str,
|
||||
)
|
||||
if inductor_kernel_stack_trace_str:
|
||||
get_metrics_context().add_to_set(
|
||||
"inductor_provenance", inductor_kernel_stack_trace_str
|
||||
metrics_context = get_metrics_context()
|
||||
if metrics_context.in_progress():
|
||||
metrics_context.add_to_set(
|
||||
"inductor_provenance",
|
||||
inductor_kernel_stack_trace_str,
|
||||
)
|
||||
|
||||
node_runtimes = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user