Separate provenance tracking to different levels (#160383)

Summary: as title. We've got request from various parties who are interested in turning on the provenance tracking by default. In this PR, we prepare to turn on part of the provenance tracking that doesn't have too much overhead by default.

- Change `provenance_tracking` config to `provenance_tracking_level`
- turn on the following provenance tracking by default when `basic_provenance_tracking`=True
    - `set_kernel_post_grad_provenance_tracing` for kernels, this add mapping between triton kernels and post_grad nodes
    - `dump_inductor_provenance_info` if we're dumping tlparse log
    - `get_graph_provenance_json` and dump `reate_mapping_pre_post_grad_nodes`. This creates mapping between pre_grad and post_grad nodes. Since we're not turning on the provenance tracking in GraphTransformObserver by default, the mapping here maybe incomplete/limited.
    - add stack trace from post grad nodes to inductor IR nodes
    - add exception swallowing for all functions above

Test Plan:
CI

Rollback Plan:

Differential Revision: D80031559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu 2025-08-15 04:59:32 +00:00 committed by PyTorch MergeBot
parent 3fc7a95176
commit aa99e0958f
12 changed files with 136 additions and 110 deletions

View File

@ -55,7 +55,7 @@ class TestGraphTransformObserver(TestCase):
) )
) )
@torch._inductor.config.patch("trace.provenance_tracking", True) @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
def test_graph_transform_observer_node_tracking(self): def test_graph_transform_observer_node_tracking(self):
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, x): def forward(self, x):
@ -156,7 +156,7 @@ class TestGraphTransformObserver(TestCase):
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE], [NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
) )
@torch._inductor.config.patch("trace.provenance_tracking", True) @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
def test_graph_transform_observer_deepcopy(self): def test_graph_transform_observer_deepcopy(self):
class SimpleLinearModel(torch.nn.Module): class SimpleLinearModel(torch.nn.Module):
def forward(self, x): def forward(self, x):
@ -179,7 +179,7 @@ class TestGraphTransformObserver(TestCase):
self.assertEqual(len(gm2._erase_node_hooks), 0) self.assertEqual(len(gm2._erase_node_hooks), 0)
self.assertEqual(len(gm2._deepcopy_hooks), 0) self.assertEqual(len(gm2._deepcopy_hooks), 0)
@torch._inductor.config.patch("trace.provenance_tracking", True) @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
def test_graph_transform_observer_replace(self): def test_graph_transform_observer_replace(self):
# the node sohuld should not be duplicated # the node sohuld should not be duplicated
class Model(torch.nn.Module): class Model(torch.nn.Module):

View File

@ -62,7 +62,7 @@ class Model3(torch.nn.Module):
@config.patch("trace.enabled", True) @config.patch("trace.enabled", True)
@config.patch("trace.provenance_tracking", True) @config.patch("trace.provenance_tracking_level", 1)
class TestProvenanceTracingArtifact(TestCase): class TestProvenanceTracingArtifact(TestCase):
""" """
This test checks that generated provenance tracing artifact from "post_grad" to This test checks that generated provenance tracing artifact from "post_grad" to

View File

@ -5385,7 +5385,7 @@ class CppScheduling(BaseScheduling):
) )
kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
# below add provenance tracing info for cpu CppKernel types # below add provenance tracing info for cpu CppKernel types
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing(nodes, kernel_name) set_kernel_post_grad_provenance_tracing(nodes, kernel_name)
kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"

View File

@ -1453,7 +1453,7 @@ class SIMDScheduling(BaseScheduling):
with V.set_kernel_handler(kernel): with V.set_kernel_handler(kernel):
src_code = kernel.codegen_kernel() src_code = kernel.codegen_kernel()
kernel_name = self.define_kernel(src_code, node_schedule, kernel) kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing( set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type] node_schedule, # type: ignore[arg-type]
kernel_name, kernel_name,
@ -1664,7 +1664,7 @@ class SIMDScheduling(BaseScheduling):
kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel) kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing( set_kernel_post_grad_provenance_tracing(
node_schedule, kernel.kernel_name node_schedule, kernel.kernel_name
) )
@ -1849,7 +1849,7 @@ class SIMDScheduling(BaseScheduling):
for src_code, kernel, _ in kernel_code_list: for src_code, kernel, _ in kernel_code_list:
kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
# dump provenance node info for ComboKernelNode/ForeachKernel type # dump provenance node info for ComboKernelNode/ForeachKernel type
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing( set_kernel_post_grad_provenance_tracing(
combo_kernel_node.snodes, kernel_name combo_kernel_node.snodes, kernel_name
) )

View File

@ -481,7 +481,7 @@ class ExternKernelOutLine(WrapperLine):
kernel_name = node.get_kernel_name() kernel_name = node.get_kernel_name()
device = d.type if (d := node.get_device()) else V.graph.device_type device = d.type if (d := node.get_device()) else V.graph.device_type
# set provenance tracing kernel mapping for ExternKernel types # set provenance tracing kernel mapping for ExternKernel types
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True)
self.wrapper._generate_extern_kernel_out_helper( self.wrapper._generate_extern_kernel_out_helper(
kernel_name, kernel_name,

View File

@ -64,7 +64,10 @@ from torch._inductor.cudagraph_utils import (
log_cudagraph_skip_and_bump_counter, log_cudagraph_skip_and_bump_counter,
PlaceholderInfo, PlaceholderInfo,
) )
from torch._inductor.debug import save_args_for_compile_fx_inner from torch._inductor.debug import (
create_mapping_pre_post_grad_nodes,
save_args_for_compile_fx_inner,
)
from torch._inductor.output_code import ( from torch._inductor.output_code import (
CompiledAOTI, CompiledAOTI,
CompiledFxGraph, CompiledFxGraph,
@ -1055,18 +1058,17 @@ def _compile_fx_inner(
log.debug("FX codegen and compilation took %.3fs", time.time() - start) log.debug("FX codegen and compilation took %.3fs", time.time() - start)
if config.trace.provenance_tracking:
# Dump provenance artifacts for debugging trace # Dump provenance artifacts for debugging trace
provenance_info = torch._inductor.debug.dump_inductor_provenance_info() if config.trace.provenance_tracking_level != 0:
# provenance_info might be None if trace.provenance_tracking is not set
if provenance_info:
trace_structured( trace_structured(
"artifact", "artifact",
metadata_fn=lambda: { metadata_fn=lambda: {
"name": "inductor_provenance_tracking_node_mappings", "name": "inductor_provenance_tracking_node_mappings",
"encoding": "json", "encoding": "json",
}, },
payload_fn=lambda: json.dumps(provenance_info), payload_fn=lambda: json.dumps(
torch._inductor.debug.dump_inductor_provenance_info()
),
) )
# This message is for printing overview information of inductor mm counts, shapes,etc after lowering # This message is for printing overview information of inductor mm counts, shapes,etc after lowering
@ -1310,20 +1312,10 @@ class _InProcessFxCompile(FxCompile):
}, },
payload_fn=lambda: inductor_post_grad_graph_str, payload_fn=lambda: inductor_post_grad_graph_str,
) )
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
provenance_tracking_json = ( provenance_tracking_json = (
torch.fx.traceback.get_graph_provenance_json(gm.graph) torch.fx.traceback.get_graph_provenance_json(gm.graph)
) )
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "inductor_post_to_pre_grad_nodes",
"encoding": "json",
},
payload_fn=lambda: json.dumps(provenance_tracking_json),
)
from torch._inductor.debug import create_mapping_pre_post_grad_nodes
torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
create_mapping_pre_post_grad_nodes( create_mapping_pre_post_grad_nodes(
torch._inductor.debug._pre_grad_graph_id, torch._inductor.debug._pre_grad_graph_id,
@ -2205,7 +2197,9 @@ def compile_fx(
with ( with (
_use_lazy_graph_module(dynamo_config.use_lazy_graph_module), _use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
enable_python_dispatcher(), enable_python_dispatcher(),
torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking), torch.fx.traceback.preserve_node_meta(
config.trace.provenance_tracking_level == 1
),
torch._inductor.debug.reset_provenance_globals(), torch._inductor.debug.reset_provenance_globals(),
): ):
# Pre-grad passes cannot be run if we weren't given a GraphModule. # Pre-grad passes cannot be run if we weren't given a GraphModule.
@ -2239,7 +2233,7 @@ def compile_fx(
) )
torch._inductor.debug._pre_grad_graph_id = id(model_.graph) torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level == 1:
for node in model_.graph.nodes: for node in model_.graph.nodes:
if node.stack_trace: if node.stack_trace:
torch._inductor.debug._inductor_pre_grad_node_stack_trace[ torch._inductor.debug._inductor_pre_grad_node_stack_trace[

View File

@ -1832,10 +1832,18 @@ class trace:
log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1"
# Save mapping info from inductor generated triton kernel to post_grad fx nodes to pre_grad fx nodes # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes
provenance_tracking = ( # Levels:
os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" # 0 - disabled (default)
or os.environ.get("INDUCTOR_PROVENANCE", "0") == "1" # 1 - normal
# 2 - basic
# Backward compatibility:
# If TORCH_COMPILE_DEBUG=1, level is set to at least 1.
# If INDUCTOR_PROVENANCE is set, use its integer value.
provenance_tracking_level: int = int(
os.environ.get(
"INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
)
) )

View File

@ -769,8 +769,6 @@ def create_mapping_pre_post_grad_nodes(
"postToPre": {}, "postToPre": {},
} }
log.info("Creating node mappings for provenance tracking")
if not isinstance(post_to_pre_grad_nodes_json, dict): if not isinstance(post_to_pre_grad_nodes_json, dict):
log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict")
return empty_return return empty_return
@ -860,8 +858,6 @@ def create_node_mapping_kernel_to_post_grad(
"postToCppCode": {}, "postToCppCode": {},
} }
log.info("Creating node mappings for provenance tracking")
if not isinstance(triton_kernel_to_post_grad_json, dict): if not isinstance(triton_kernel_to_post_grad_json, dict):
log.error( log.error(
"Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
@ -905,6 +901,7 @@ def create_node_mapping_kernel_to_post_grad(
def dump_inductor_provenance_info( def dump_inductor_provenance_info(
filename: str = "inductor_generated_kernel_to_post_grad_nodes.json", filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
) -> dict[str, Any]: ) -> dict[str, Any]:
try:
global _pre_grad_graph_id global _pre_grad_graph_id
global _inductor_post_to_pre_grad_nodes global _inductor_post_to_pre_grad_nodes
global _inductor_triton_kernel_to_post_grad_node_info global _inductor_triton_kernel_to_post_grad_node_info
@ -927,6 +924,13 @@ def dump_inductor_provenance_info(
) as fd: ) as fd:
json.dump(node_mapping, fd) json.dump(node_mapping, fd)
return node_mapping return node_mapping
except Exception as e:
# Since this is just debugging, it should never interfere with regular
# program execution, so we use this try-except to guard against any error
# TODO: log the error to scuba table for better signal
log.error("Unexpected error in dump_inductor_provenance_info: %s", e)
log.error(traceback.format_exc())
return {}
def set_kernel_post_grad_provenance_tracing( def set_kernel_post_grad_provenance_tracing(
@ -934,6 +938,7 @@ def set_kernel_post_grad_provenance_tracing(
kernel_name: str, kernel_name: str,
is_extern: bool = False, is_extern: bool = False,
) -> None: ) -> None:
try:
from .codegen.simd_kernel_features import DisableReduction, EnableReduction from .codegen.simd_kernel_features import DisableReduction, EnableReduction
global _inductor_triton_kernel_to_post_grad_node_info global _inductor_triton_kernel_to_post_grad_node_info
@ -970,6 +975,12 @@ def set_kernel_post_grad_provenance_tracing(
for origin in snode.node.origins for origin in snode.node.origins
if origin.name not in curr_node_info if origin.name not in curr_node_info
) )
except Exception as e:
# Since this is just debugging, it should never interfere with regular
# program execution, so we use this try-except to guard against any error
# TODO: log the error to scuba table for better signal
log.error("Unexpected error in set_kernel_post_grad_provenance_tracing: %s", e)
log.error(traceback.format_exc())
def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:

View File

@ -129,7 +129,7 @@ def _transfer_meta(
# transfer metadata after pattern matching occurs. # transfer metadata after pattern matching occurs.
# skip "val" and "tensor_meta" because this info is too specific; it's unlikely # skip "val" and "tensor_meta" because this info is too specific; it's unlikely
# to remain accurate after pattern matching has occurred. # to remain accurate after pattern matching has occurred.
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level == 1:
# We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node.
new_from_node = new_meta.get("from_node", []).copy() new_from_node = new_meta.get("from_node", []).copy()
new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE))

View File

@ -2345,7 +2345,8 @@ def make_fx(
record_module_stack, record_module_stack,
_allow_fake_constant, _allow_fake_constant,
_error_on_data_dependent_ops, _error_on_data_dependent_ops,
record_stack_traces=record_stack_traces or config.trace.provenance_tracking, record_stack_traces=record_stack_traces
or config.trace.provenance_tracking_level == 1,
) )
@functools.wraps(f) @functools.wraps(f)

View File

@ -43,7 +43,8 @@ class GraphTransformObserver:
self.log_url = log_url self.log_url = log_url
self.active = ( self.active = (
self.log_url is not None or inductor_config.trace.provenance_tracking self.log_url is not None
or inductor_config.trace.provenance_tracking_level == 1
) )
if self.active: if self.active:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import copy import copy
import logging
import traceback import traceback
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
@ -10,6 +11,8 @@ from .graph import Graph
from .node import Node from .node import Node
log = logging.getLogger(__name__)
__all__ = [ __all__ = [
"preserve_node_meta", "preserve_node_meta",
"has_preserved_node_meta", "has_preserved_node_meta",
@ -311,6 +314,7 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
""" """
Given an fx.Graph, return a json that contains the provenance information of each node. Given an fx.Graph, return a json that contains the provenance information of each node.
""" """
try:
provenance_tracking_json = {} provenance_tracking_json = {}
for node in graph.nodes: for node in graph.nodes:
if node.op == "call_function": if node.op == "call_function":
@ -320,3 +324,10 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
else [] else []
) )
return provenance_tracking_json return provenance_tracking_json
except Exception as e:
# Since this is just debugging, it should never interfere with regular
# program execution, so we use this try-except to guard against any error
# TODO: log the error to scuba table for better signal
log.error("Unexpected error in get_graph_provenance_json: %s", e)
log.error(traceback.format_exc())
return {}