Separate provenance tracking to different levels (#160383)

Summary: as title. We've got request from various parties who are interested in turning on the provenance tracking by default. In this PR, we prepare to turn on part of the provenance tracking that doesn't have too much overhead by default.

- Change `provenance_tracking` config to `provenance_tracking_level`
- turn on the following provenance tracking by default when `basic_provenance_tracking`=True
    - `set_kernel_post_grad_provenance_tracing` for kernels, this add mapping between triton kernels and post_grad nodes
    - `dump_inductor_provenance_info` if we're dumping tlparse log
    - `get_graph_provenance_json` and dump `reate_mapping_pre_post_grad_nodes`. This creates mapping between pre_grad and post_grad nodes. Since we're not turning on the provenance tracking in GraphTransformObserver by default, the mapping here maybe incomplete/limited.
    - add stack trace from post grad nodes to inductor IR nodes
    - add exception swallowing for all functions above

Test Plan:
CI

Rollback Plan:

Differential Revision: D80031559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu 2025-08-15 04:59:32 +00:00 committed by PyTorch MergeBot
parent 3fc7a95176
commit aa99e0958f
12 changed files with 136 additions and 110 deletions

View File

@ -55,7 +55,7 @@ class TestGraphTransformObserver(TestCase):
) )
) )
@torch._inductor.config.patch("trace.provenance_tracking", True) @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
def test_graph_transform_observer_node_tracking(self): def test_graph_transform_observer_node_tracking(self):
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, x): def forward(self, x):
@ -156,7 +156,7 @@ class TestGraphTransformObserver(TestCase):
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE], [NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
) )
@torch._inductor.config.patch("trace.provenance_tracking", True) @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
def test_graph_transform_observer_deepcopy(self): def test_graph_transform_observer_deepcopy(self):
class SimpleLinearModel(torch.nn.Module): class SimpleLinearModel(torch.nn.Module):
def forward(self, x): def forward(self, x):
@ -179,7 +179,7 @@ class TestGraphTransformObserver(TestCase):
self.assertEqual(len(gm2._erase_node_hooks), 0) self.assertEqual(len(gm2._erase_node_hooks), 0)
self.assertEqual(len(gm2._deepcopy_hooks), 0) self.assertEqual(len(gm2._deepcopy_hooks), 0)
@torch._inductor.config.patch("trace.provenance_tracking", True) @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
def test_graph_transform_observer_replace(self): def test_graph_transform_observer_replace(self):
# the node sohuld should not be duplicated # the node sohuld should not be duplicated
class Model(torch.nn.Module): class Model(torch.nn.Module):

View File

@ -62,7 +62,7 @@ class Model3(torch.nn.Module):
@config.patch("trace.enabled", True) @config.patch("trace.enabled", True)
@config.patch("trace.provenance_tracking", True) @config.patch("trace.provenance_tracking_level", 1)
class TestProvenanceTracingArtifact(TestCase): class TestProvenanceTracingArtifact(TestCase):
""" """
This test checks that generated provenance tracing artifact from "post_grad" to This test checks that generated provenance tracing artifact from "post_grad" to

View File

@ -5385,7 +5385,7 @@ class CppScheduling(BaseScheduling):
) )
kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
# below add provenance tracing info for cpu CppKernel types # below add provenance tracing info for cpu CppKernel types
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing(nodes, kernel_name) set_kernel_post_grad_provenance_tracing(nodes, kernel_name)
kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"

View File

@ -1453,7 +1453,7 @@ class SIMDScheduling(BaseScheduling):
with V.set_kernel_handler(kernel): with V.set_kernel_handler(kernel):
src_code = kernel.codegen_kernel() src_code = kernel.codegen_kernel()
kernel_name = self.define_kernel(src_code, node_schedule, kernel) kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing( set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type] node_schedule, # type: ignore[arg-type]
kernel_name, kernel_name,
@ -1664,7 +1664,7 @@ class SIMDScheduling(BaseScheduling):
kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel) kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing( set_kernel_post_grad_provenance_tracing(
node_schedule, kernel.kernel_name node_schedule, kernel.kernel_name
) )
@ -1849,7 +1849,7 @@ class SIMDScheduling(BaseScheduling):
for src_code, kernel, _ in kernel_code_list: for src_code, kernel, _ in kernel_code_list:
kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
# dump provenance node info for ComboKernelNode/ForeachKernel type # dump provenance node info for ComboKernelNode/ForeachKernel type
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing( set_kernel_post_grad_provenance_tracing(
combo_kernel_node.snodes, kernel_name combo_kernel_node.snodes, kernel_name
) )

View File

@ -481,7 +481,7 @@ class ExternKernelOutLine(WrapperLine):
kernel_name = node.get_kernel_name() kernel_name = node.get_kernel_name()
device = d.type if (d := node.get_device()) else V.graph.device_type device = d.type if (d := node.get_device()) else V.graph.device_type
# set provenance tracing kernel mapping for ExternKernel types # set provenance tracing kernel mapping for ExternKernel types
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True)
self.wrapper._generate_extern_kernel_out_helper( self.wrapper._generate_extern_kernel_out_helper(
kernel_name, kernel_name,

View File

@ -64,7 +64,10 @@ from torch._inductor.cudagraph_utils import (
log_cudagraph_skip_and_bump_counter, log_cudagraph_skip_and_bump_counter,
PlaceholderInfo, PlaceholderInfo,
) )
from torch._inductor.debug import save_args_for_compile_fx_inner from torch._inductor.debug import (
create_mapping_pre_post_grad_nodes,
save_args_for_compile_fx_inner,
)
from torch._inductor.output_code import ( from torch._inductor.output_code import (
CompiledAOTI, CompiledAOTI,
CompiledFxGraph, CompiledFxGraph,
@ -1055,19 +1058,18 @@ def _compile_fx_inner(
log.debug("FX codegen and compilation took %.3fs", time.time() - start) log.debug("FX codegen and compilation took %.3fs", time.time() - start)
if config.trace.provenance_tracking: # Dump provenance artifacts for debugging trace
# Dump provenance artifacts for debugging trace if config.trace.provenance_tracking_level != 0:
provenance_info = torch._inductor.debug.dump_inductor_provenance_info() trace_structured(
# provenance_info might be None if trace.provenance_tracking is not set "artifact",
if provenance_info: metadata_fn=lambda: {
trace_structured( "name": "inductor_provenance_tracking_node_mappings",
"artifact", "encoding": "json",
metadata_fn=lambda: { },
"name": "inductor_provenance_tracking_node_mappings", payload_fn=lambda: json.dumps(
"encoding": "json", torch._inductor.debug.dump_inductor_provenance_info()
}, ),
payload_fn=lambda: json.dumps(provenance_info), )
)
# This message is for printing overview information of inductor mm counts, shapes,etc after lowering # This message is for printing overview information of inductor mm counts, shapes,etc after lowering
if log.isEnabledFor(logging.INFO): if log.isEnabledFor(logging.INFO):
@ -1310,20 +1312,10 @@ class _InProcessFxCompile(FxCompile):
}, },
payload_fn=lambda: inductor_post_grad_graph_str, payload_fn=lambda: inductor_post_grad_graph_str,
) )
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level != 0:
provenance_tracking_json = ( provenance_tracking_json = (
torch.fx.traceback.get_graph_provenance_json(gm.graph) torch.fx.traceback.get_graph_provenance_json(gm.graph)
) )
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "inductor_post_to_pre_grad_nodes",
"encoding": "json",
},
payload_fn=lambda: json.dumps(provenance_tracking_json),
)
from torch._inductor.debug import create_mapping_pre_post_grad_nodes
torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
create_mapping_pre_post_grad_nodes( create_mapping_pre_post_grad_nodes(
torch._inductor.debug._pre_grad_graph_id, torch._inductor.debug._pre_grad_graph_id,
@ -2205,7 +2197,9 @@ def compile_fx(
with ( with (
_use_lazy_graph_module(dynamo_config.use_lazy_graph_module), _use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
enable_python_dispatcher(), enable_python_dispatcher(),
torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking), torch.fx.traceback.preserve_node_meta(
config.trace.provenance_tracking_level == 1
),
torch._inductor.debug.reset_provenance_globals(), torch._inductor.debug.reset_provenance_globals(),
): ):
# Pre-grad passes cannot be run if we weren't given a GraphModule. # Pre-grad passes cannot be run if we weren't given a GraphModule.
@ -2239,7 +2233,7 @@ def compile_fx(
) )
torch._inductor.debug._pre_grad_graph_id = id(model_.graph) torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level == 1:
for node in model_.graph.nodes: for node in model_.graph.nodes:
if node.stack_trace: if node.stack_trace:
torch._inductor.debug._inductor_pre_grad_node_stack_trace[ torch._inductor.debug._inductor_pre_grad_node_stack_trace[

View File

@ -1832,10 +1832,18 @@ class trace:
log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1"
# Save mapping info from inductor generated triton kernel to post_grad fx nodes to pre_grad fx nodes # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes
provenance_tracking = ( # Levels:
os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" # 0 - disabled (default)
or os.environ.get("INDUCTOR_PROVENANCE", "0") == "1" # 1 - normal
# 2 - basic
# Backward compatibility:
# If TORCH_COMPILE_DEBUG=1, level is set to at least 1.
# If INDUCTOR_PROVENANCE is set, use its integer value.
provenance_tracking_level: int = int(
os.environ.get(
"INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
)
) )

View File

@ -769,8 +769,6 @@ def create_mapping_pre_post_grad_nodes(
"postToPre": {}, "postToPre": {},
} }
log.info("Creating node mappings for provenance tracking")
if not isinstance(post_to_pre_grad_nodes_json, dict): if not isinstance(post_to_pre_grad_nodes_json, dict):
log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict")
return empty_return return empty_return
@ -860,8 +858,6 @@ def create_node_mapping_kernel_to_post_grad(
"postToCppCode": {}, "postToCppCode": {},
} }
log.info("Creating node mappings for provenance tracking")
if not isinstance(triton_kernel_to_post_grad_json, dict): if not isinstance(triton_kernel_to_post_grad_json, dict):
log.error( log.error(
"Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
@ -905,28 +901,36 @@ def create_node_mapping_kernel_to_post_grad(
def dump_inductor_provenance_info( def dump_inductor_provenance_info(
filename: str = "inductor_generated_kernel_to_post_grad_nodes.json", filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
) -> dict[str, Any]: ) -> dict[str, Any]:
global _pre_grad_graph_id try:
global _inductor_post_to_pre_grad_nodes global _pre_grad_graph_id
global _inductor_triton_kernel_to_post_grad_node_info global _inductor_post_to_pre_grad_nodes
if config.trace.enabled: global _inductor_triton_kernel_to_post_grad_node_info
with V.debug.fopen(filename, "w") as fd:
log.info("Writing provenance tracing debugging info to %s", fd.name)
json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
node_mapping = {}
if _pre_grad_graph_id:
node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
_inductor_triton_kernel_to_post_grad_node_info
)
node_mapping = {
**_inductor_post_to_pre_grad_nodes,
**node_mapping_kernel,
}
if config.trace.enabled: if config.trace.enabled:
with V.debug.fopen( with V.debug.fopen(filename, "w") as fd:
"inductor_provenance_tracking_node_mappings.json", "w" log.info("Writing provenance tracing debugging info to %s", fd.name)
) as fd: json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
json.dump(node_mapping, fd) node_mapping = {}
return node_mapping if _pre_grad_graph_id:
node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
_inductor_triton_kernel_to_post_grad_node_info
)
node_mapping = {
**_inductor_post_to_pre_grad_nodes,
**node_mapping_kernel,
}
if config.trace.enabled:
with V.debug.fopen(
"inductor_provenance_tracking_node_mappings.json", "w"
) as fd:
json.dump(node_mapping, fd)
return node_mapping
except Exception as e:
# Since this is just debugging, it should never interfere with regular
# program execution, so we use this try-except to guard against any error
# TODO: log the error to scuba table for better signal
log.error("Unexpected error in dump_inductor_provenance_info: %s", e)
log.error(traceback.format_exc())
return {}
def set_kernel_post_grad_provenance_tracing( def set_kernel_post_grad_provenance_tracing(
@ -934,42 +938,49 @@ def set_kernel_post_grad_provenance_tracing(
kernel_name: str, kernel_name: str,
is_extern: bool = False, is_extern: bool = False,
) -> None: ) -> None:
from .codegen.simd_kernel_features import DisableReduction, EnableReduction try:
from .codegen.simd_kernel_features import DisableReduction, EnableReduction
global _inductor_triton_kernel_to_post_grad_node_info global _inductor_triton_kernel_to_post_grad_node_info
if is_extern: if is_extern:
assert isinstance(node_schedule, ExternKernelOut) assert isinstance(node_schedule, ExternKernelOut)
curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault( curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
kernel_name, [] kernel_name, []
)
# 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
# "origin_node" is more precise and says that the contents of this node corresponds
# EXACTLY to the output of a particular FX node, but it's not always available
if node_schedule.origin_node:
origin_node_name = node_schedule.origin_node.name
if origin_node_name not in curr_node_info:
curr_node_info.append(origin_node_name)
else:
curr_node_info.extend(
origin.name
for origin in node_schedule.origins
if origin.name not in curr_node_info
) )
else: # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
assert isinstance(node_schedule, list) # "origin_node" is more precise and says that the contents of this node corresponds
for snode in node_schedule: # EXACTLY to the output of a particular FX node, but it's not always available
if snode not in (EnableReduction, DisableReduction): if node_schedule.origin_node:
if snode.node is not None: origin_node_name = node_schedule.origin_node.name
curr_node_info = ( if origin_node_name not in curr_node_info:
_inductor_triton_kernel_to_post_grad_node_info.setdefault( curr_node_info.append(origin_node_name)
kernel_name, [] else:
curr_node_info.extend(
origin.name
for origin in node_schedule.origins
if origin.name not in curr_node_info
)
else:
assert isinstance(node_schedule, list)
for snode in node_schedule:
if snode not in (EnableReduction, DisableReduction):
if snode.node is not None:
curr_node_info = (
_inductor_triton_kernel_to_post_grad_node_info.setdefault(
kernel_name, []
)
) )
) curr_node_info.extend(
curr_node_info.extend( origin.name
origin.name for origin in snode.node.origins
for origin in snode.node.origins if origin.name not in curr_node_info
if origin.name not in curr_node_info )
) except Exception as e:
# Since this is just debugging, it should never interfere with regular
# program execution, so we use this try-except to guard against any error
# TODO: log the error to scuba table for better signal
log.error("Unexpected error in set_kernel_post_grad_provenance_tracing: %s", e)
log.error(traceback.format_exc())
def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:

View File

@ -129,7 +129,7 @@ def _transfer_meta(
# transfer metadata after pattern matching occurs. # transfer metadata after pattern matching occurs.
# skip "val" and "tensor_meta" because this info is too specific; it's unlikely # skip "val" and "tensor_meta" because this info is too specific; it's unlikely
# to remain accurate after pattern matching has occurred. # to remain accurate after pattern matching has occurred.
if config.trace.provenance_tracking: if config.trace.provenance_tracking_level == 1:
# We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node.
new_from_node = new_meta.get("from_node", []).copy() new_from_node = new_meta.get("from_node", []).copy()
new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE))

View File

@ -2345,7 +2345,8 @@ def make_fx(
record_module_stack, record_module_stack,
_allow_fake_constant, _allow_fake_constant,
_error_on_data_dependent_ops, _error_on_data_dependent_ops,
record_stack_traces=record_stack_traces or config.trace.provenance_tracking, record_stack_traces=record_stack_traces
or config.trace.provenance_tracking_level == 1,
) )
@functools.wraps(f) @functools.wraps(f)

View File

@ -43,7 +43,8 @@ class GraphTransformObserver:
self.log_url = log_url self.log_url = log_url
self.active = ( self.active = (
self.log_url is not None or inductor_config.trace.provenance_tracking self.log_url is not None
or inductor_config.trace.provenance_tracking_level == 1
) )
if self.active: if self.active:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import copy import copy
import logging
import traceback import traceback
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
@ -10,6 +11,8 @@ from .graph import Graph
from .node import Node from .node import Node
log = logging.getLogger(__name__)
__all__ = [ __all__ = [
"preserve_node_meta", "preserve_node_meta",
"has_preserved_node_meta", "has_preserved_node_meta",
@ -311,12 +314,20 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
""" """
Given an fx.Graph, return a json that contains the provenance information of each node. Given an fx.Graph, return a json that contains the provenance information of each node.
""" """
provenance_tracking_json = {} try:
for node in graph.nodes: provenance_tracking_json = {}
if node.op == "call_function": for node in graph.nodes:
provenance_tracking_json[node.name] = ( if node.op == "call_function":
[source.to_dict() for source in node.meta["from_node"]] provenance_tracking_json[node.name] = (
if "from_node" in node.meta [source.to_dict() for source in node.meta["from_node"]]
else [] if "from_node" in node.meta
) else []
return provenance_tracking_json )
return provenance_tracking_json
except Exception as e:
# Since this is just debugging, it should never interfere with regular
# program execution, so we use this try-except to guard against any error
# TODO: log the error to scuba table for better signal
log.error("Unexpected error in get_graph_provenance_json: %s", e)
log.error(traceback.format_exc())
return {}