Separate provenance tracking into different levels (#160383)

Summary: As titled. We've received requests from various parties interested in turning provenance tracking on by default. This PR prepares for that by enabling, by default, the parts of provenance tracking that don't add much overhead.

- Change the `provenance_tracking` config to `provenance_tracking_level`
- Turn on the following provenance tracking by default when `basic_provenance_tracking=True` (see the sketch after this list):
    - `set_kernel_post_grad_provenance_tracing` for kernels; this adds a mapping between Triton kernels and post_grad nodes
    - `dump_inductor_provenance_info` when we're dumping the tlparse log
    - `get_graph_provenance_json`, and dump `create_mapping_pre_post_grad_nodes`. This creates a mapping between pre_grad and post_grad nodes. Since we're not turning on provenance tracking in `GraphTransformObserver` by default, the mapping here may be incomplete/limited.
    - Add stack traces from post_grad nodes to inductor IR nodes
    - Add exception swallowing for all of the functions above
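
A minimal sketch of exercising the new knob from user code (the toy function and script are illustrative, not part of this PR):

```python
# Hypothetical usage sketch: any nonzero provenance_tracking_level enables the
# kernel-to-post_grad mapping and the tlparse artifact dump wired up below.
import torch
import torch._inductor.config as inductor_config

inductor_config.trace.provenance_tracking_level = 1  # or INDUCTOR_PROVENANCE=1

@torch.compile
def f(x):
    return torch.relu(x) + 1

f(torch.randn(8, 8))
```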

Test Plan:
CI

Rollback Plan:

Differential Revision: D80031559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
Authored by Shangdi Yu on 2025-08-15 04:59:32 +00:00; committed by PyTorch MergeBot
parent 3fc7a95176, commit aa99e0958f
12 changed files with 136 additions and 110 deletions


@@ -55,7 +55,7 @@ class TestGraphTransformObserver(TestCase):
         )
     )

-    @torch._inductor.config.patch("trace.provenance_tracking", True)
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
     def test_graph_transform_observer_node_tracking(self):
         class M(torch.nn.Module):
             def forward(self, x):
@@ -156,7 +156,7 @@ class TestGraphTransformObserver(TestCase):
             [NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
         )

-    @torch._inductor.config.patch("trace.provenance_tracking", True)
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
     def test_graph_transform_observer_deepcopy(self):
         class SimpleLinearModel(torch.nn.Module):
             def forward(self, x):
@@ -179,7 +179,7 @@ class TestGraphTransformObserver(TestCase):
         self.assertEqual(len(gm2._erase_node_hooks), 0)
         self.assertEqual(len(gm2._deepcopy_hooks), 0)

-    @torch._inductor.config.patch("trace.provenance_tracking", True)
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
     def test_graph_transform_observer_replace(self):
         # the node should not be duplicated
         class Model(torch.nn.Module):


@@ -62,7 +62,7 @@ class Model3(torch.nn.Module):

 @config.patch("trace.enabled", True)
-@config.patch("trace.provenance_tracking", True)
+@config.patch("trace.provenance_tracking_level", 1)
 class TestProvenanceTracingArtifact(TestCase):
     """
     This test checks that generated provenance tracing artifact from "post_grad" to


@@ -5385,7 +5385,7 @@ class CppScheduling(BaseScheduling):
         )
         kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
         # below add provenance tracing info for cpu CppKernel types
-        if config.trace.provenance_tracking:
+        if config.trace.provenance_tracking_level != 0:
             set_kernel_post_grad_provenance_tracing(nodes, kernel_name)
         kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"


@@ -1453,7 +1453,7 @@ class SIMDScheduling(BaseScheduling):
             with V.set_kernel_handler(kernel):
                 src_code = kernel.codegen_kernel()
             kernel_name = self.define_kernel(src_code, node_schedule, kernel)
-            if config.trace.provenance_tracking:
+            if config.trace.provenance_tracking_level != 0:
                 set_kernel_post_grad_provenance_tracing(
                     node_schedule,  # type: ignore[arg-type]
                     kernel_name,
@@ -1664,7 +1664,7 @@ class SIMDScheduling(BaseScheduling):
             kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)

-            if config.trace.provenance_tracking:
+            if config.trace.provenance_tracking_level != 0:
                 set_kernel_post_grad_provenance_tracing(
                     node_schedule, kernel.kernel_name
                 )
@@ -1849,7 +1849,7 @@ class SIMDScheduling(BaseScheduling):
         for src_code, kernel, _ in kernel_code_list:
             kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
             # dump provenance node info for ComboKernelNode/ForeachKernel type
-            if config.trace.provenance_tracking:
+            if config.trace.provenance_tracking_level != 0:
                 set_kernel_post_grad_provenance_tracing(
                     combo_kernel_node.snodes, kernel_name
                 )


@@ -481,7 +481,7 @@ class ExternKernelOutLine(WrapperLine):
         kernel_name = node.get_kernel_name()
         device = d.type if (d := node.get_device()) else V.graph.device_type
         # set provenance tracing kernel mapping for ExternKernel types
-        if config.trace.provenance_tracking:
+        if config.trace.provenance_tracking_level != 0:
             set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True)
         self.wrapper._generate_extern_kernel_out_helper(
             kernel_name,


@@ -64,7 +64,10 @@ from torch._inductor.cudagraph_utils import (
     log_cudagraph_skip_and_bump_counter,
     PlaceholderInfo,
 )
-from torch._inductor.debug import save_args_for_compile_fx_inner
+from torch._inductor.debug import (
+    create_mapping_pre_post_grad_nodes,
+    save_args_for_compile_fx_inner,
+)
 from torch._inductor.output_code import (
     CompiledAOTI,
     CompiledFxGraph,
@@ -1055,19 +1058,18 @@ def _compile_fx_inner(
         log.debug("FX codegen and compilation took %.3fs", time.time() - start)

-        if config.trace.provenance_tracking:
-            # Dump provenance artifacts for debugging trace
-            provenance_info = torch._inductor.debug.dump_inductor_provenance_info()
-            # provenance_info might be None if trace.provenance_tracking is not set
-            if provenance_info:
-                trace_structured(
-                    "artifact",
-                    metadata_fn=lambda: {
-                        "name": "inductor_provenance_tracking_node_mappings",
-                        "encoding": "json",
-                    },
-                    payload_fn=lambda: json.dumps(provenance_info),
-                )
+        # Dump provenance artifacts for debugging trace
+        if config.trace.provenance_tracking_level != 0:
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "inductor_provenance_tracking_node_mappings",
+                    "encoding": "json",
+                },
+                payload_fn=lambda: json.dumps(
+                    torch._inductor.debug.dump_inductor_provenance_info()
+                ),
+            )

         # This message is for printing overview information of inductor mm counts, shapes,etc after lowering
         if log.isEnabledFor(logging.INFO):
@@ -1310,20 +1312,10 @@ class _InProcessFxCompile(FxCompile):
                     },
                     payload_fn=lambda: inductor_post_grad_graph_str,
                 )
-                if config.trace.provenance_tracking:
+                if config.trace.provenance_tracking_level != 0:
                     provenance_tracking_json = (
                         torch.fx.traceback.get_graph_provenance_json(gm.graph)
                     )
-                    trace_structured(
-                        "artifact",
-                        metadata_fn=lambda: {
-                            "name": "inductor_post_to_pre_grad_nodes",
-                            "encoding": "json",
-                        },
-                        payload_fn=lambda: json.dumps(provenance_tracking_json),
-                    )
-                    from torch._inductor.debug import create_mapping_pre_post_grad_nodes
-
                     torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
                         create_mapping_pre_post_grad_nodes(
                             torch._inductor.debug._pre_grad_graph_id,
@@ -2205,7 +2197,9 @@ def compile_fx(
     with (
         _use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
         enable_python_dispatcher(),
-        torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking),
+        torch.fx.traceback.preserve_node_meta(
+            config.trace.provenance_tracking_level == 1
+        ),
         torch._inductor.debug.reset_provenance_globals(),
     ):
         # Pre-grad passes cannot be run if we weren't given a GraphModule.
@@ -2239,7 +2233,7 @@
             )
             torch._inductor.debug._pre_grad_graph_id = id(model_.graph)

-            if config.trace.provenance_tracking:
+            if config.trace.provenance_tracking_level == 1:
                 for node in model_.graph.nodes:
                     if node.stack_trace:
                         torch._inductor.debug._inductor_pre_grad_node_stack_trace[
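
For context, `preserve_node_meta` takes a boolean enable flag, so `from_node` provenance metadata is only recorded at level 1. A rough standalone sketch of that gating (the toy module is illustrative, not part of this PR):

```python
# Sketch: "from_node" provenance metadata is only recorded while this context
# is active; compile_fx now ties it to provenance_tracking_level == 1.
import torch
import torch._inductor.config as inductor_config
import torch.fx.traceback as fx_traceback

enable = inductor_config.trace.provenance_tracking_level == 1
with fx_traceback.preserve_node_meta(enable):
    gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
```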

View File

@@ -1832,10 +1832,18 @@ class trace:
     log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1"

-    # Save mapping info from inductor generated triton kernel to post_grad fx nodes to pre_grad fx nodes
-    provenance_tracking = (
-        os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
-        or os.environ.get("INDUCTOR_PROVENANCE", "0") == "1"
-    )
+    # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes
+    # Levels:
+    #   0 - disabled (default)
+    #   1 - normal
+    #   2 - basic
+    # Backward compatibility:
+    #   If TORCH_COMPILE_DEBUG=1, level is set to at least 1.
+    #   If INDUCTOR_PROVENANCE is set, use its integer value.
+    provenance_tracking_level: int = int(
+        os.environ.get(
+            "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
+        )
+    )
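
Spelled out: `INDUCTOR_PROVENANCE` now carries an integer level and takes precedence, while `TORCH_COMPILE_DEBUG=1` falls back to level 1. A standalone sketch of the resolution (it mirrors the config expression above; the helper name is hypothetical):

```python
import os

def resolve_provenance_level() -> int:
    # INDUCTOR_PROVENANCE wins; otherwise TORCH_COMPILE_DEBUG=1 implies level 1.
    return int(
        os.environ.get(
            "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
        )
    )

os.environ.pop("INDUCTOR_PROVENANCE", None)
os.environ["TORCH_COMPILE_DEBUG"] = "1"
assert resolve_provenance_level() == 1
os.environ["INDUCTOR_PROVENANCE"] = "2"
assert resolve_provenance_level() == 2
```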

View File

@@ -769,8 +769,6 @@ def create_mapping_pre_post_grad_nodes(
         "postToPre": {},
     }

-    log.info("Creating node mappings for provenance tracking")
-
     if not isinstance(post_to_pre_grad_nodes_json, dict):
         log.error("Provenance tracking error: post_to_pre_grad_nodes_json is not a dict")
         return empty_return
@@ -860,8 +858,6 @@ def create_node_mapping_kernel_to_post_grad(
         "postToCppCode": {},
     }

-    log.info("Creating node mappings for provenance tracking")
-
     if not isinstance(triton_kernel_to_post_grad_json, dict):
         log.error(
             "Provenance tracking error: triton_kernel_to_post_grad_json is not a dict"
@@ -905,28 +901,36 @@ def create_node_mapping_kernel_to_post_grad(
 def dump_inductor_provenance_info(
     filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
 ) -> dict[str, Any]:
-    global _pre_grad_graph_id
-    global _inductor_post_to_pre_grad_nodes
-    global _inductor_triton_kernel_to_post_grad_node_info
-    if config.trace.enabled:
-        with V.debug.fopen(filename, "w") as fd:
-            log.info("Writing provenance tracing debugging info to %s", fd.name)
-            json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
-    node_mapping = {}
-    if _pre_grad_graph_id:
-        node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
-            _inductor_triton_kernel_to_post_grad_node_info
-        )
-        node_mapping = {
-            **_inductor_post_to_pre_grad_nodes,
-            **node_mapping_kernel,
-        }
+    try:
+        global _pre_grad_graph_id
+        global _inductor_post_to_pre_grad_nodes
+        global _inductor_triton_kernel_to_post_grad_node_info
         if config.trace.enabled:
-            with V.debug.fopen(
-                "inductor_provenance_tracking_node_mappings.json", "w"
-            ) as fd:
-                json.dump(node_mapping, fd)
-    return node_mapping
+            with V.debug.fopen(filename, "w") as fd:
+                log.info("Writing provenance tracing debugging info to %s", fd.name)
+                json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
+        node_mapping = {}
+        if _pre_grad_graph_id:
+            node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
+                _inductor_triton_kernel_to_post_grad_node_info
+            )
+            node_mapping = {
+                **_inductor_post_to_pre_grad_nodes,
+                **node_mapping_kernel,
+            }
+            if config.trace.enabled:
+                with V.debug.fopen(
+                    "inductor_provenance_tracking_node_mappings.json", "w"
+                ) as fd:
+                    json.dump(node_mapping, fd)
+        return node_mapping
+    except Exception as e:
+        # Since this is just debugging, it should never interfere with regular
+        # program execution, so we use this try-except to guard against any error
+        # TODO: log the error to scuba table for better signal
+        log.error("Unexpected error in dump_inductor_provenance_info: %s", e)
+        log.error(traceback.format_exc())
+        return {}


 def set_kernel_post_grad_provenance_tracing(
@@ -934,42 +938,49 @@ def set_kernel_post_grad_provenance_tracing(
     kernel_name: str,
     is_extern: bool = False,
 ) -> None:
-    from .codegen.simd_kernel_features import DisableReduction, EnableReduction
+    try:
+        from .codegen.simd_kernel_features import DisableReduction, EnableReduction

-    global _inductor_triton_kernel_to_post_grad_node_info
-    if is_extern:
-        assert isinstance(node_schedule, ExternKernelOut)
-        curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-            kernel_name, []
-        )
-        # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
-        # "origin_node" is more precise and says that the contents of this node corresponds
-        # EXACTLY to the output of a particular FX node, but it's not always available
-        if node_schedule.origin_node:
-            origin_node_name = node_schedule.origin_node.name
-            if origin_node_name not in curr_node_info:
-                curr_node_info.append(origin_node_name)
-        else:
-            curr_node_info.extend(
-                origin.name
-                for origin in node_schedule.origins
-                if origin.name not in curr_node_info
-            )
-    else:
-        assert isinstance(node_schedule, list)
-        for snode in node_schedule:
-            if snode not in (EnableReduction, DisableReduction):
-                if snode.node is not None:
-                    curr_node_info = (
-                        _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-                            kernel_name, []
-                        )
-                    )
-                    curr_node_info.extend(
-                        origin.name
-                        for origin in snode.node.origins
-                        if origin.name not in curr_node_info
-                    )
+        global _inductor_triton_kernel_to_post_grad_node_info
+        if is_extern:
+            assert isinstance(node_schedule, ExternKernelOut)
+            curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                kernel_name, []
+            )
+            # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
+            # "origin_node" is more precise and says that the contents of this node corresponds
+            # EXACTLY to the output of a particular FX node, but it's not always available
+            if node_schedule.origin_node:
+                origin_node_name = node_schedule.origin_node.name
+                if origin_node_name not in curr_node_info:
+                    curr_node_info.append(origin_node_name)
+            else:
+                curr_node_info.extend(
+                    origin.name
+                    for origin in node_schedule.origins
+                    if origin.name not in curr_node_info
+                )
+        else:
+            assert isinstance(node_schedule, list)
+            for snode in node_schedule:
+                if snode not in (EnableReduction, DisableReduction):
+                    if snode.node is not None:
+                        curr_node_info = (
+                            _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                                kernel_name, []
+                            )
+                        )
+                        curr_node_info.extend(
+                            origin.name
+                            for origin in snode.node.origins
+                            if origin.name not in curr_node_info
+                        )
+    except Exception as e:
+        # Since this is just debugging, it should never interfere with regular
+        # program execution, so we use this try-except to guard against any error
+        # TODO: log the error to scuba table for better signal
+        log.error("Unexpected error in set_kernel_post_grad_provenance_tracing: %s", e)
+        log.error(traceback.format_exc())


 def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
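
The wrappers above all follow the same swallow-and-log pattern: provenance tracking is debug-only, so it must never break compilation. The same idea expressed as a generic decorator (a hypothetical helper, not part of this PR):

```python
import functools
import logging
import traceback

log = logging.getLogger(__name__)

def swallow_exceptions(fn):
    """Log and suppress any exception; provenance tracking is best-effort."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            log.error("Unexpected error in %s: %s", fn.__name__, e)
            log.error(traceback.format_exc())
            return None
    return wrapper
```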


@@ -129,7 +129,7 @@ def _transfer_meta(
     # transfer metadata after pattern matching occurs.
     # skip "val" and "tensor_meta" because this info is too specific; it's unlikely
     # to remain accurate after pattern matching has occurred.
-    if config.trace.provenance_tracking:
+    if config.trace.provenance_tracking_level == 1:
         # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node.
         new_from_node = new_meta.get("from_node", []).copy()
         new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE))


@@ -2345,7 +2345,8 @@ def make_fx(
         record_module_stack,
         _allow_fake_constant,
         _error_on_data_dependent_ops,
-        record_stack_traces=record_stack_traces or config.trace.provenance_tracking,
+        record_stack_traces=record_stack_traces
+        or config.trace.provenance_tracking_level == 1,
     )

     @functools.wraps(f)


@@ -43,7 +43,8 @@ class GraphTransformObserver:
         self.log_url = log_url

         self.active = (
-            self.log_url is not None or inductor_config.trace.provenance_tracking
+            self.log_url is not None
+            or inductor_config.trace.provenance_tracking_level == 1
         )
         if self.active:


@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import copy
+import logging
 import traceback
 from contextlib import contextmanager
 from enum import Enum
@@ -10,6 +11,8 @@ from .graph import Graph
 from .node import Node

+log = logging.getLogger(__name__)
+
 __all__ = [
     "preserve_node_meta",
     "has_preserved_node_meta",
@@ -311,12 +314,20 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
     """
     Given an fx.Graph, return a json that contains the provenance information of each node.
     """
-    provenance_tracking_json = {}
-    for node in graph.nodes:
-        if node.op == "call_function":
-            provenance_tracking_json[node.name] = (
-                [source.to_dict() for source in node.meta["from_node"]]
-                if "from_node" in node.meta
-                else []
-            )
-    return provenance_tracking_json
+    try:
+        provenance_tracking_json = {}
+        for node in graph.nodes:
+            if node.op == "call_function":
+                provenance_tracking_json[node.name] = (
+                    [source.to_dict() for source in node.meta["from_node"]]
+                    if "from_node" in node.meta
+                    else []
+                )
+        return provenance_tracking_json
+    except Exception as e:
+        # Since this is just debugging, it should never interfere with regular
+        # program execution, so we use this try-except to guard against any error
+        # TODO: log the error to scuba table for better signal
+        log.error("Unexpected error in get_graph_provenance_json: %s", e)
+        log.error(traceback.format_exc())
+        return {}
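
For reference, the returned dict maps each `call_function` node to the list of `NodeSource.to_dict()` entries recorded in its `from_node` meta. A plausible shape (the node keys come from the code above; the per-entry field names are illustrative, and the exact `to_dict()` schema may differ):

```python
example_provenance_json = {
    "addmm_1": [
        {
            "name": "linear",          # node in the earlier (pre_grad) graph
            "pass_name": "post_grad",  # pass that created/replaced the node
            "action": ["replace"],     # NodeSourceAction values
        }
    ],
    "relu": [],  # call_function node with no recorded "from_node" sources
}
```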