Separate provenance tracking to different levels (#160383)
Summary: as title. We've received requests from various parties interested in turning on provenance tracking by default. In this PR, we prepare to turn on, by default, the parts of provenance tracking that don't add much overhead.
- Change the `provenance_tracking` config to `provenance_tracking_level`
- Turn on the following provenance tracking by default when basic provenance tracking is enabled (i.e. `provenance_tracking_level != 0`):
  - `set_kernel_post_grad_provenance_tracing` for kernels; this adds a mapping between Triton kernels and post_grad nodes
  - `dump_inductor_provenance_info` when dumping the tlparse log
  - `get_graph_provenance_json`, and dump the result of `create_mapping_pre_post_grad_nodes`. This creates a mapping between pre_grad and post_grad nodes. Since we're not turning on provenance tracking in GraphTransformObserver by default, the mapping here may be incomplete/limited.
- Add stack traces from post_grad nodes to inductor IR nodes
- Add exception swallowing to all of the functions above
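
As a usage sketch (not part of the PR itself), the new level can be set through the inductor config patch used in the tests below, or through the environment variables the config reads at import time:

    import torch
    import torch._inductor.config as inductor_config

    # Level 1 = full tracking; any non-zero level = basic kernel mappings.
    # Equivalent env vars: INDUCTOR_PROVENANCE=<level>, or TORCH_COMPILE_DEBUG=1.
    with inductor_config.patch("trace.provenance_tracking_level", 1):
        model = torch.nn.Linear(4, 4)  # placeholder model, for illustration only
        compiled = torch.compile(model)
        compiled(torch.randn(2, 4))

Per the config comments in the diff below, `INDUCTOR_PROVENANCE=2` would select the basic level only.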
Test Plan:
CI
Rollback Plan:
Differential Revision: D80031559
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
This commit is contained in:
parent 3fc7a95176
commit aa99e0958f

@@ -55,7 +55,7 @@ class TestGraphTransformObserver(TestCase):
             )
         )

-    @torch._inductor.config.patch("trace.provenance_tracking", True)
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
     def test_graph_transform_observer_node_tracking(self):
         class M(torch.nn.Module):
             def forward(self, x):

@@ -156,7 +156,7 @@ class TestGraphTransformObserver(TestCase):
             [NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
         )

-    @torch._inductor.config.patch("trace.provenance_tracking", True)
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
     def test_graph_transform_observer_deepcopy(self):
         class SimpleLinearModel(torch.nn.Module):
             def forward(self, x):

@@ -179,7 +179,7 @@ class TestGraphTransformObserver(TestCase):
         self.assertEqual(len(gm2._erase_node_hooks), 0)
         self.assertEqual(len(gm2._deepcopy_hooks), 0)

-    @torch._inductor.config.patch("trace.provenance_tracking", True)
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
     def test_graph_transform_observer_replace(self):
         # the node should not be duplicated
         class Model(torch.nn.Module):

@@ -62,7 +62,7 @@ class Model3(torch.nn.Module):


 @config.patch("trace.enabled", True)
-@config.patch("trace.provenance_tracking", True)
+@config.patch("trace.provenance_tracking_level", 1)
 class TestProvenanceTracingArtifact(TestCase):
     """
     This test checks that generated provenance tracing artifact from "post_grad" to

@@ -5385,7 +5385,7 @@ class CppScheduling(BaseScheduling):
             )
         kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
         # below add provenance tracing info for cpu CppKernel types
-        if config.trace.provenance_tracking:
+        if config.trace.provenance_tracking_level != 0:
             set_kernel_post_grad_provenance_tracing(nodes, kernel_name)

         kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"

@@ -1453,7 +1453,7 @@ class SIMDScheduling(BaseScheduling):
         with V.set_kernel_handler(kernel):
             src_code = kernel.codegen_kernel()
         kernel_name = self.define_kernel(src_code, node_schedule, kernel)
-        if config.trace.provenance_tracking:
+        if config.trace.provenance_tracking_level != 0:
             set_kernel_post_grad_provenance_tracing(
                 node_schedule,  # type: ignore[arg-type]
                 kernel_name,

@@ -1664,7 +1664,7 @@ class SIMDScheduling(BaseScheduling):

         kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)

-        if config.trace.provenance_tracking:
+        if config.trace.provenance_tracking_level != 0:
             set_kernel_post_grad_provenance_tracing(
                 node_schedule, kernel.kernel_name
             )

@@ -1849,7 +1849,7 @@ class SIMDScheduling(BaseScheduling):
         for src_code, kernel, _ in kernel_code_list:
             kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
             # dump provenance node info for ComboKernelNode/ForeachKernel type
-            if config.trace.provenance_tracking:
+            if config.trace.provenance_tracking_level != 0:
                 set_kernel_post_grad_provenance_tracing(
                     combo_kernel_node.snodes, kernel_name
                 )

@@ -481,7 +481,7 @@ class ExternKernelOutLine(WrapperLine):
         kernel_name = node.get_kernel_name()
         device = d.type if (d := node.get_device()) else V.graph.device_type
         # set provenance tracing kernel mapping for ExternKernel types
-        if config.trace.provenance_tracking:
+        if config.trace.provenance_tracking_level != 0:
             set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True)
         self.wrapper._generate_extern_kernel_out_helper(
             kernel_name,

@@ -64,7 +64,10 @@ from torch._inductor.cudagraph_utils import (
     log_cudagraph_skip_and_bump_counter,
     PlaceholderInfo,
 )
-from torch._inductor.debug import save_args_for_compile_fx_inner
+from torch._inductor.debug import (
+    create_mapping_pre_post_grad_nodes,
+    save_args_for_compile_fx_inner,
+)
 from torch._inductor.output_code import (
     CompiledAOTI,
     CompiledFxGraph,

@@ -1055,19 +1058,18 @@ def _compile_fx_inner(

     log.debug("FX codegen and compilation took %.3fs", time.time() - start)

-    if config.trace.provenance_tracking:
-        # Dump provenance artifacts for debugging trace
-        provenance_info = torch._inductor.debug.dump_inductor_provenance_info()
-        # provenance_info might be None if trace.provenance_tracking is not set
-        if provenance_info:
-            trace_structured(
-                "artifact",
-                metadata_fn=lambda: {
-                    "name": "inductor_provenance_tracking_node_mappings",
-                    "encoding": "json",
-                },
-                payload_fn=lambda: json.dumps(provenance_info),
-            )
+    # Dump provenance artifacts for debugging trace
+    if config.trace.provenance_tracking_level != 0:
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "inductor_provenance_tracking_node_mappings",
+                "encoding": "json",
+            },
+            payload_fn=lambda: json.dumps(
+                torch._inductor.debug.dump_inductor_provenance_info()
+            ),
+        )

     # This message is for printing overview information of inductor mm counts, shapes, etc. after lowering
     if log.isEnabledFor(logging.INFO):

@@ -1310,20 +1312,10 @@ class _InProcessFxCompile(FxCompile):
                     },
                     payload_fn=lambda: inductor_post_grad_graph_str,
                 )
-                if config.trace.provenance_tracking:
+                if config.trace.provenance_tracking_level != 0:
                     provenance_tracking_json = (
                         torch.fx.traceback.get_graph_provenance_json(gm.graph)
                     )
-                    trace_structured(
-                        "artifact",
-                        metadata_fn=lambda: {
-                            "name": "inductor_post_to_pre_grad_nodes",
-                            "encoding": "json",
-                        },
-                        payload_fn=lambda: json.dumps(provenance_tracking_json),
-                    )
-                    from torch._inductor.debug import create_mapping_pre_post_grad_nodes
-
                     torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
                         create_mapping_pre_post_grad_nodes(
                             torch._inductor.debug._pre_grad_graph_id,

@@ -2205,7 +2197,9 @@ def compile_fx(
     with (
         _use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
         enable_python_dispatcher(),
-        torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking),
+        torch.fx.traceback.preserve_node_meta(
+            config.trace.provenance_tracking_level == 1
+        ),
         torch._inductor.debug.reset_provenance_globals(),
     ):
         # Pre-grad passes cannot be run if we weren't given a GraphModule.

@@ -2239,7 +2233,7 @@ def compile_fx(
             )
             torch._inductor.debug._pre_grad_graph_id = id(model_.graph)

-            if config.trace.provenance_tracking:
+            if config.trace.provenance_tracking_level == 1:
                 for node in model_.graph.nodes:
                     if node.stack_trace:
                         torch._inductor.debug._inductor_pre_grad_node_stack_trace[

@@ -1832,10 +1832,18 @@ class trace:

     log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1"

-    # Save mapping info from inductor generated triton kernel to post_grad fx nodes to pre_grad fx nodes
-    provenance_tracking = (
-        os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
-        or os.environ.get("INDUCTOR_PROVENANCE", "0") == "1"
-    )
+    # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes
+    # Levels:
+    # 0 - disabled (default)
+    # 1 - normal
+    # 2 - basic
+    # Backward compatibility:
+    # If TORCH_COMPILE_DEBUG=1, level is set to at least 1.
+    # If INDUCTOR_PROVENANCE is set, use its integer value.
+    provenance_tracking_level: int = int(
+        os.environ.get(
+            "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
+        )
+    )

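Taken together with the call sites above, the level semantics are: any non-zero level enables the lightweight kernel-to-post_grad mappings (checked as `provenance_tracking_level != 0`), while the heavier pre_grad paths (GraphTransformObserver, `preserve_node_meta`, stack-trace recording in `make_fx`) additionally require `provenance_tracking_level == 1`. A minimal sketch of that gating, with hypothetical helper names:

    import os

    # Mirrors the config default above: INDUCTOR_PROVENANCE wins if set;
    # otherwise TORCH_COMPILE_DEBUG=1 implies level 1.
    level = int(
        os.environ.get("INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0"))
    )

    def basic_tracking_enabled() -> bool:  # hypothetical helper
        # kernel -> post_grad mappings, tlparse artifact dumps
        return level != 0

    def full_tracking_enabled() -> bool:  # hypothetical helper
        # GraphTransformObserver, preserve_node_meta, make_fx stack traces
        return level == 1
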
@@ -769,8 +769,6 @@ def create_mapping_pre_post_grad_nodes(
         "postToPre": {},
     }

-    log.info("Creating node mappings for provenance tracking")
-
     if not isinstance(post_to_pre_grad_nodes_json, dict):
         log.error("Provenance tracking error: post_to_pre_grad_nodes_json is not a dict")
         return empty_return

@@ -860,8 +858,6 @@ def create_node_mapping_kernel_to_post_grad(
         "postToCppCode": {},
     }

-    log.info("Creating node mappings for provenance tracking")
-
     if not isinstance(triton_kernel_to_post_grad_json, dict):
         log.error(
             "Provenance tracking error: triton_kernel_to_post_grad_json is not a dict"

@@ -905,28 +901,36 @@ def create_node_mapping_kernel_to_post_grad(
 def dump_inductor_provenance_info(
     filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
 ) -> dict[str, Any]:
-    global _pre_grad_graph_id
-    global _inductor_post_to_pre_grad_nodes
-    global _inductor_triton_kernel_to_post_grad_node_info
-    if config.trace.enabled:
-        with V.debug.fopen(filename, "w") as fd:
-            log.info("Writing provenance tracing debugging info to %s", fd.name)
-            json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
-    node_mapping = {}
-    if _pre_grad_graph_id:
-        node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
-            _inductor_triton_kernel_to_post_grad_node_info
-        )
-        node_mapping = {
-            **_inductor_post_to_pre_grad_nodes,
-            **node_mapping_kernel,
-        }
-    if config.trace.enabled:
-        with V.debug.fopen(
-            "inductor_provenance_tracking_node_mappings.json", "w"
-        ) as fd:
-            json.dump(node_mapping, fd)
-    return node_mapping
+    try:
+        global _pre_grad_graph_id
+        global _inductor_post_to_pre_grad_nodes
+        global _inductor_triton_kernel_to_post_grad_node_info
+        if config.trace.enabled:
+            with V.debug.fopen(filename, "w") as fd:
+                log.info("Writing provenance tracing debugging info to %s", fd.name)
+                json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
+        node_mapping = {}
+        if _pre_grad_graph_id:
+            node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
+                _inductor_triton_kernel_to_post_grad_node_info
+            )
+            node_mapping = {
+                **_inductor_post_to_pre_grad_nodes,
+                **node_mapping_kernel,
+            }
+        if config.trace.enabled:
+            with V.debug.fopen(
+                "inductor_provenance_tracking_node_mappings.json", "w"
+            ) as fd:
+                json.dump(node_mapping, fd)
+        return node_mapping
+    except Exception as e:
+        # Since this is just debugging, it should never interfere with regular
+        # program execution, so we use this try-except to guard against any error
+        # TODO: log the error to scuba table for better signal
+        log.error("Unexpected error in dump_inductor_provenance_info: %s", e)
+        log.error(traceback.format_exc())
+        return {}


 def set_kernel_post_grad_provenance_tracing(

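The same log-and-swallow guard recurs in each of the debug-only entry points this PR touches. Purely as an illustration (the PR inlines the try/except rather than factoring it out), the pattern could be expressed as a decorator:

    import functools
    import logging
    import traceback

    log = logging.getLogger(__name__)

    def swallow_exceptions(default=None):
        """Log and suppress any exception: provenance tracking is debug-only
        and must never break compilation (hypothetical helper, not in the PR)."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    log.error("Unexpected error in %s: %s", fn.__name__, e)
                    log.error(traceback.format_exc())
                    return default
            return wrapper
        return decorator
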
@@ -934,42 +938,49 @@ def set_kernel_post_grad_provenance_tracing(
     kernel_name: str,
     is_extern: bool = False,
 ) -> None:
-    from .codegen.simd_kernel_features import DisableReduction, EnableReduction
-
-    global _inductor_triton_kernel_to_post_grad_node_info
-    if is_extern:
-        assert isinstance(node_schedule, ExternKernelOut)
-        curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-            kernel_name, []
-        )
-        # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
-        # "origin_node" is more precise and says that the contents of this node corresponds
-        # EXACTLY to the output of a particular FX node, but it's not always available
-        if node_schedule.origin_node:
-            origin_node_name = node_schedule.origin_node.name
-            if origin_node_name not in curr_node_info:
-                curr_node_info.append(origin_node_name)
-        else:
-            curr_node_info.extend(
-                origin.name
-                for origin in node_schedule.origins
-                if origin.name not in curr_node_info
-            )
-    else:
-        assert isinstance(node_schedule, list)
-        for snode in node_schedule:
-            if snode not in (EnableReduction, DisableReduction):
-                if snode.node is not None:
-                    curr_node_info = (
-                        _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-                            kernel_name, []
-                        )
-                    )
-                    curr_node_info.extend(
-                        origin.name
-                        for origin in snode.node.origins
-                        if origin.name not in curr_node_info
-                    )
+    try:
+        from .codegen.simd_kernel_features import DisableReduction, EnableReduction
+
+        global _inductor_triton_kernel_to_post_grad_node_info
+        if is_extern:
+            assert isinstance(node_schedule, ExternKernelOut)
+            curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                kernel_name, []
+            )
+            # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
+            # "origin_node" is more precise and says that the contents of this node corresponds
+            # EXACTLY to the output of a particular FX node, but it's not always available
+            if node_schedule.origin_node:
+                origin_node_name = node_schedule.origin_node.name
+                if origin_node_name not in curr_node_info:
+                    curr_node_info.append(origin_node_name)
+            else:
+                curr_node_info.extend(
+                    origin.name
+                    for origin in node_schedule.origins
+                    if origin.name not in curr_node_info
+                )
+        else:
+            assert isinstance(node_schedule, list)
+            for snode in node_schedule:
+                if snode not in (EnableReduction, DisableReduction):
+                    if snode.node is not None:
+                        curr_node_info = (
+                            _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                                kernel_name, []
+                            )
+                        )
+                        curr_node_info.extend(
+                            origin.name
+                            for origin in snode.node.origins
+                            if origin.name not in curr_node_info
+                        )
+    except Exception as e:
+        # Since this is just debugging, it should never interfere with regular
+        # program execution, so we use this try-except to guard against any error
+        # TODO: log the error to scuba table for better signal
+        log.error("Unexpected error in set_kernel_post_grad_provenance_tracing: %s", e)
+        log.error(traceback.format_exc())


 def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:

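The comment carried over in the hunk above distinguishes the two provenance sources on an inductor IR node. A condensed sketch of the preference logic, using a hypothetical free function (`origin_node` and `origins` are the attribute names from this diff):

    def post_grad_origins(ir_node) -> list[str]:
        # Prefer origin_node: it marks this IR node's output as corresponding
        # EXACTLY to one post_grad FX node, but it is not always available.
        if ir_node.origin_node is not None:
            return [ir_node.origin_node.name]
        # Fall back to the coarser 'origins' set: every FX node that
        # contributed to the fused kernel.
        return sorted(origin.name for origin in ir_node.origins)
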
@@ -129,7 +129,7 @@ def _transfer_meta(
     # transfer metadata after pattern matching occurs.
     # skip "val" and "tensor_meta" because this info is too specific; it's unlikely
     # to remain accurate after pattern matching has occurred.
-    if config.trace.provenance_tracking:
+    if config.trace.provenance_tracking_level == 1:
         # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node.
         new_from_node = new_meta.get("from_node", []).copy()
         new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE))

@@ -2345,7 +2345,8 @@ def make_fx(
         record_module_stack,
         _allow_fake_constant,
         _error_on_data_dependent_ops,
-        record_stack_traces=record_stack_traces or config.trace.provenance_tracking,
+        record_stack_traces=record_stack_traces
+        or config.trace.provenance_tracking_level == 1,
     )

     @functools.wraps(f)

@@ -43,7 +43,8 @@ class GraphTransformObserver:
         self.log_url = log_url

         self.active = (
-            self.log_url is not None or inductor_config.trace.provenance_tracking
+            self.log_url is not None
+            or inductor_config.trace.provenance_tracking_level == 1
         )

         if self.active:

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import copy
+import logging
 import traceback
 from contextlib import contextmanager
 from enum import Enum

@@ -10,6 +11,8 @@ from .graph import Graph
 from .node import Node


+log = logging.getLogger(__name__)
+
 __all__ = [
     "preserve_node_meta",
     "has_preserved_node_meta",

@@ -311,12 +314,20 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
     """
     Given an fx.Graph, return a json that contains the provenance information of each node.
     """
-    provenance_tracking_json = {}
-    for node in graph.nodes:
-        if node.op == "call_function":
-            provenance_tracking_json[node.name] = (
-                [source.to_dict() for source in node.meta["from_node"]]
-                if "from_node" in node.meta
-                else []
-            )
-    return provenance_tracking_json
+    try:
+        provenance_tracking_json = {}
+        for node in graph.nodes:
+            if node.op == "call_function":
+                provenance_tracking_json[node.name] = (
+                    [source.to_dict() for source in node.meta["from_node"]]
+                    if "from_node" in node.meta
+                    else []
+                )
+        return provenance_tracking_json
+    except Exception as e:
+        # Since this is just debugging, it should never interfere with regular
+        # program execution, so we use this try-except to guard against any error
+        # TODO: log the error to scuba table for better signal
+        log.error("Unexpected error in get_graph_provenance_json: %s", e)
+        log.error(traceback.format_exc())
+        return {}

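For orientation, a sketch of consuming the resulting dump offline; the artifact name `inductor_post_to_pre_grad_nodes` comes from this diff, while the per-entry fields of the serialized NodeSource are not shown here, so the access below is deliberately defensive:

    import json

    # "inductor_post_to_pre_grad_nodes" is the artifact name used in the diff above;
    # the file path here is an assumption for illustration.
    with open("inductor_post_to_pre_grad_nodes.json") as f:
        provenance = json.load(f)

    for node_name, sources in provenance.items():
        # Each entry is a NodeSource.to_dict(); exact field names may vary, hence .get().
        print(node_name, "<-", [s.get("name") for s in sources])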