diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py index 307be1e7212..d9dcb8504ba 100644 --- a/test/fx/test_fx_xform_observer.py +++ b/test/fx/test_fx_xform_observer.py @@ -55,7 +55,7 @@ class TestGraphTransformObserver(TestCase): ) ) - @torch._inductor.config.patch("trace.provenance_tracking", True) + @torch._inductor.config.patch("trace.provenance_tracking_level", 1) def test_graph_transform_observer_node_tracking(self): class M(torch.nn.Module): def forward(self, x): @@ -156,7 +156,7 @@ class TestGraphTransformObserver(TestCase): [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], ) - @torch._inductor.config.patch("trace.provenance_tracking", True) + @torch._inductor.config.patch("trace.provenance_tracking_level", 1) def test_graph_transform_observer_deepcopy(self): class SimpleLinearModel(torch.nn.Module): def forward(self, x): @@ -179,7 +179,7 @@ class TestGraphTransformObserver(TestCase): self.assertEqual(len(gm2._erase_node_hooks), 0) self.assertEqual(len(gm2._deepcopy_hooks), 0) - @torch._inductor.config.patch("trace.provenance_tracking", True) + @torch._inductor.config.patch("trace.provenance_tracking_level", 1) def test_graph_transform_observer_replace(self): # the node should not be duplicated class Model(torch.nn.Module): diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index 77e099cf0cb..22adac53b4f 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -62,7 +62,7 @@ class Model3(torch.nn.Module): @config.patch("trace.enabled", True) -@config.patch("trace.provenance_tracking", True) +@config.patch("trace.provenance_tracking_level", 1) class TestProvenanceTracingArtifact(TestCase): """ This test checks that generated provenance tracing artifact from "post_grad" to diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 1ee9d033d4f..e71a1d91b0f 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -5385,7 +5385,7 @@ class CppScheduling(BaseScheduling): ) kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) # below add provenance tracing info for cpu CppKernel types - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level != 0: set_kernel_post_grad_provenance_tracing(nodes, kernel_name) kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 5b1350a9239..da077e725f7 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1453,7 +1453,7 @@ class SIMDScheduling(BaseScheduling): with V.set_kernel_handler(kernel): src_code = kernel.codegen_kernel() kernel_name = self.define_kernel(src_code, node_schedule, kernel) - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level != 0: set_kernel_post_grad_provenance_tracing( node_schedule, # type: ignore[arg-type] kernel_name, @@ -1664,7 +1664,7 @@ class SIMDScheduling(BaseScheduling): kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel) - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level != 0: set_kernel_post_grad_provenance_tracing( node_schedule, kernel.kernel_name ) @@ -1849,7 +1849,7 @@ class SIMDScheduling(BaseScheduling): for src_code, kernel, _ in kernel_code_list: kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) # dump provenance node info for ComboKernelNode/ForeachKernel
type - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level != 0: set_kernel_post_grad_provenance_tracing( combo_kernel_node.snodes, kernel_name ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index ae6c809aba9..27d8a28cb96 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -481,7 +481,7 @@ class ExternKernelOutLine(WrapperLine): kernel_name = node.get_kernel_name() device = d.type if (d := node.get_device()) else V.graph.device_type # set provenance tracing kernel mapping for ExternKernel types - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level != 0: set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) self.wrapper._generate_extern_kernel_out_helper( kernel_name, diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index eaab9020f1e..2ee0d6a9caa 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -64,7 +64,10 @@ from torch._inductor.cudagraph_utils import ( log_cudagraph_skip_and_bump_counter, PlaceholderInfo, ) -from torch._inductor.debug import save_args_for_compile_fx_inner +from torch._inductor.debug import ( + create_mapping_pre_post_grad_nodes, + save_args_for_compile_fx_inner, +) from torch._inductor.output_code import ( CompiledAOTI, CompiledFxGraph, @@ -1055,19 +1058,18 @@ def _compile_fx_inner( log.debug("FX codegen and compilation took %.3fs", time.time() - start) - if config.trace.provenance_tracking: - # Dump provenance artifacts for debugging trace - provenance_info = torch._inductor.debug.dump_inductor_provenance_info() - # provenance_info might be None if trace.provenance_tracking is not set - if provenance_info: - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "inductor_provenance_tracking_node_mappings", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(provenance_info), - ) + # Dump provenance artifacts for the debug trace + if config.trace.provenance_tracking_level != 0: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + "encoding": "json", + }, + payload_fn=lambda: json.dumps( + torch._inductor.debug.dump_inductor_provenance_info() + ), + ) # This message is for printing overview information of inductor mm counts, shapes, etc. after lowering if log.isEnabledFor(logging.INFO): @@ -1310,20 +1312,10 @@ class _InProcessFxCompile(FxCompile): }, payload_fn=lambda: inductor_post_grad_graph_str, ) - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level != 0: provenance_tracking_json = ( torch.fx.traceback.get_graph_provenance_json(gm.graph) ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "inductor_post_to_pre_grad_nodes", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(provenance_tracking_json), - ) - from torch._inductor.debug import create_mapping_pre_post_grad_nodes - torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( create_mapping_pre_post_grad_nodes( torch._inductor.debug._pre_grad_graph_id, @@ -2205,7 +2197,9 @@ def compile_fx( with ( _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), enable_python_dispatcher(), - torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking), + torch.fx.traceback.preserve_node_meta( + config.trace.provenance_tracking_level == 1 + ), torch._inductor.debug.reset_provenance_globals(), ): # Pre-grad passes cannot be run if we weren't given a GraphModule.
@@ -2239,7 +2233,7 @@ def compile_fx( ) torch._inductor.debug._pre_grad_graph_id = id(model_.graph) - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level == 1: for node in model_.graph.nodes: if node.stack_trace: torch._inductor.debug._inductor_pre_grad_node_stack_trace[ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index deebfa273ba..335a9d01cd7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1832,10 +1832,18 @@ class trace: log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" - # Save mapping info from inductor generated triton kernel to post_grad fx nodes to pre_grad fx nodes - provenance_tracking = ( - os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" - or os.environ.get("INDUCTOR_PROVENANCE", "0") == "1" + # Save mapping info from inductor-generated kernels to post_grad/pre_grad fx nodes + # Levels: + # 0 - disabled (default) + # 1 - full tracking: kernel-to-post_grad mapping, plus pre_grad stack traces and "from_node" node metadata + # 2 - basic tracking: kernel-to-post_grad mapping only, with lower overhead + # Backward compatibility: + # If INDUCTOR_PROVENANCE is set, its integer value takes precedence. + # Otherwise, TORCH_COMPILE_DEBUG=1 sets the level to 1. + provenance_tracking_level: int = int( + os.environ.get( + "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0") ) ) diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index 06430b02756..a53c0689d6b 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -769,8 +769,6 @@ def create_mapping_pre_post_grad_nodes( "postToPre": {}, } - log.info("Creating node mappings for provenance tracking") - if not isinstance(post_to_pre_grad_nodes_json, dict): log.error("Provenance tracking error: post_to_pre_grad_nodes_json is not a dict") return empty_return @@ -860,8 +858,6 @@ def create_node_mapping_kernel_to_post_grad( "postToCppCode": {}, } - log.info("Creating node mappings for provenance tracking") - if not isinstance(triton_kernel_to_post_grad_json, dict): log.error( "Provenance tracking error: triton_kernel_to_post_grad_json is not a dict" ) @@ -905,28 +901,36 @@ def create_node_mapping_kernel_to_post_grad( def dump_inductor_provenance_info( filename: str = "inductor_generated_kernel_to_post_grad_nodes.json", ) -> dict[str, Any]: - global _pre_grad_graph_id - global _inductor_post_to_pre_grad_nodes - global _inductor_triton_kernel_to_post_grad_node_info - if config.trace.enabled: - with V.debug.fopen(filename, "w") as fd: - log.info("Writing provenance tracing debugging info to %s", fd.name) - json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd) - node_mapping = {} - if _pre_grad_graph_id: - node_mapping_kernel = create_node_mapping_kernel_to_post_grad( - _inductor_triton_kernel_to_post_grad_node_info - ) - node_mapping = { - **_inductor_post_to_pre_grad_nodes, - **node_mapping_kernel, - } + try: + global _pre_grad_graph_id + global _inductor_post_to_pre_grad_nodes + global _inductor_triton_kernel_to_post_grad_node_info if config.trace.enabled: - with V.debug.fopen( - "inductor_provenance_tracking_node_mappings.json", "w" - ) as fd: - json.dump(node_mapping, fd) - return node_mapping + with V.debug.fopen(filename, "w") as fd: + log.info("Writing provenance tracing debugging info to %s", fd.name) + json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd) + node_mapping = {} + if _pre_grad_graph_id: + node_mapping_kernel = create_node_mapping_kernel_to_post_grad( + _inductor_triton_kernel_to_post_grad_node_info + ) + node_mapping = { + **_inductor_post_to_pre_grad_nodes, + **node_mapping_kernel, + } + if config.trace.enabled: + with V.debug.fopen( +
"inductor_provenance_tracking_node_mappings.json", "w" + ) as fd: + json.dump(node_mapping, fd) + return node_mapping + except Exception as e: + # Since this is just debugging, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + # TODO: log the error to scuba table for better signal + log.error("Unexpected error in dump_inductor_provenance_info: %s", e) + log.error(traceback.format_exc()) + return {} def set_kernel_post_grad_provenance_tracing( @@ -934,42 +938,49 @@ def set_kernel_post_grad_provenance_tracing( kernel_name: str, is_extern: bool = False, ) -> None: - from .codegen.simd_kernel_features import DisableReduction, EnableReduction + try: + from .codegen.simd_kernel_features import DisableReduction, EnableReduction - global _inductor_triton_kernel_to_post_grad_node_info - if is_extern: - assert isinstance(node_schedule, ExternKernelOut) - curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault( - kernel_name, [] - ) - # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel. - # "origin_node" is more precise and says that the contents of this node corresponds - # EXACTLY to the output of a particular FX node, but it's not always available - if node_schedule.origin_node: - origin_node_name = node_schedule.origin_node.name - if origin_node_name not in curr_node_info: - curr_node_info.append(origin_node_name) - else: - curr_node_info.extend( - origin.name - for origin in node_schedule.origins - if origin.name not in curr_node_info + global _inductor_triton_kernel_to_post_grad_node_info + if is_extern: + assert isinstance(node_schedule, ExternKernelOut) + curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] ) - else: - assert isinstance(node_schedule, list) - for snode in node_schedule: - if snode not in (EnableReduction, DisableReduction): - if snode.node is not None: - curr_node_info = ( - _inductor_triton_kernel_to_post_grad_node_info.setdefault( - kernel_name, [] + # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel. 
+ # "origin_node" is more precise and says that the contents of this node corresponds + # EXACTLY to the output of a particular FX node, but it's not always available + if node_schedule.origin_node: + origin_node_name = node_schedule.origin_node.name + if origin_node_name not in curr_node_info: + curr_node_info.append(origin_node_name) + else: + curr_node_info.extend( + origin.name + for origin in node_schedule.origins + if origin.name not in curr_node_info + ) + else: + assert isinstance(node_schedule, list) + for snode in node_schedule: + if snode not in (EnableReduction, DisableReduction): + if snode.node is not None: + curr_node_info = ( + _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) ) - ) - curr_node_info.extend( - origin.name - for origin in snode.node.origins - if origin.name not in curr_node_info - ) + curr_node_info.extend( + origin.name + for origin in snode.node.origins + if origin.name not in curr_node_info + ) + except Exception as e: + # Since this is just debugging, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + # TODO: log the error to scuba table for better signal + log.error("Unexpected error in set_kernel_post_grad_provenance_tracing: %s", e) + log.error(traceback.format_exc()) def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 44ebc7ad41c..a80994b2d6b 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -129,7 +129,7 @@ def _transfer_meta( # transfer metadata after pattern matching occurs. # skip "val" and "tensor_meta" because this info is too specific; it's unlikely # to remain accurate after pattern matching has occurred. - if config.trace.provenance_tracking: + if config.trace.provenance_tracking_level == 1: # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. 
new_from_node = new_meta.get("from_node", []).copy() new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 8f203c9ef24..7d456cbb5cb 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -2345,7 +2345,8 @@ def make_fx( record_module_stack, _allow_fake_constant, _error_on_data_dependent_ops, - record_stack_traces=record_stack_traces or config.trace.provenance_tracking, + record_stack_traces=record_stack_traces + or config.trace.provenance_tracking_level == 1, ) @functools.wraps(f) diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index 75a6ef6a2bc..6479af66589 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -43,7 +43,8 @@ class GraphTransformObserver: self.log_url = log_url self.active = ( - self.log_url is not None or inductor_config.trace.provenance_tracking + self.log_url is not None + or inductor_config.trace.provenance_tracking_level == 1 ) if self.active: diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 648a80b87b6..0a8ddbc6e16 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import copy +import logging import traceback from contextlib import contextmanager from enum import Enum @@ -10,6 +11,8 @@ from .graph import Graph from .node import Node +log = logging.getLogger(__name__) + __all__ = [ "preserve_node_meta", "has_preserved_node_meta", @@ -311,12 +314,20 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]: """ Given an fx.Graph, return a json that contains the provenance information of each node. """ - provenance_tracking_json = {} - for node in graph.nodes: - if node.op == "call_function": - provenance_tracking_json[node.name] = ( - [source.to_dict() for source in node.meta["from_node"]] - if "from_node" in node.meta - else [] - ) - return provenance_tracking_json + try: + provenance_tracking_json = {} + for node in graph.nodes: + if node.op == "call_function": + provenance_tracking_json[node.name] = ( + [source.to_dict() for source in node.meta["from_node"]] + if "from_node" in node.meta + else [] + ) + return provenance_tracking_json + except Exception as e: + # Since this is just debugging, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + # TODO: log the error to scuba table for better signal + log.error("Unexpected error in get_graph_provenance_json: %s", e) + log.error(traceback.format_exc()) + return {}
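A note on the new semantics: INDUCTOR_PROVENANCE now carries an integer level instead of a boolean flag, and the diff gates features on two predicates. provenance_tracking_level != 0 enables the kernel-to-FX-node mapping and the artifact dumps, while provenance_tracking_level == 1 additionally enables pre_grad stack-trace collection, "from_node" metadata transfer, and GraphTransformObserver. The sketch below is a minimal standalone illustration of that logic; the helper name resolve_provenance_tracking_level is hypothetical and not part of this diff, though its body mirrors the expression added to torch/_inductor/config.py.

    import os

    def resolve_provenance_tracking_level() -> int:
        # Mirrors the parsing added to torch/_inductor/config.py:
        # INDUCTOR_PROVENANCE, when set, takes precedence; otherwise
        # TORCH_COMPILE_DEBUG=1 maps to level 1 (unset or "0" maps to 0).
        return int(
            os.environ.get(
                "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
            )
        )

    level = resolve_provenance_tracking_level()
    kernel_mapping_enabled = level != 0  # kernel <-> post_grad node mapping, artifact dumps
    full_tracking_enabled = level == 1   # pre_grad stack traces, "from_node" meta, observer hooks

For example, running with INDUCTOR_PROVENANCE=2 keeps the kernel mapping paths active while skipping the heavier stack-trace collection that level 1 performs.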