Summary: as title. This is requested by the zoomer team so they can add stack trace information to profiler results.

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r stack_traces
```

Rollback Plan:

Differential Revision: D80050233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160779
Approved by: https://github.com/angelayi
This commit is contained in:
parent b7ca502f29
commit b74c7cd335
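For context on how the new artifact is meant to be consumed: it is emitted through the `torch.__trace` structured-trace logger under the name `inductor_provenance_tracking_kernel_stack_traces`, with a JSON payload mapping each generated kernel to the stack traces of the source lines it covers. The sketch below mirrors the capture pattern used in the test added by this commit; the filter/formatter class names are illustrative, not a PyTorch API, and (as in the test, which requires CUDA and Triton) it assumes provenance tracking is enabled and an Inductor-compiled run actually happens.

```python
import io
import json
import logging

import torch

trace_log = logging.getLogger("torch.__trace")


class KernelStackTraceFilter(logging.Filter):
    # Keep only the structured-trace records carrying the new artifact.
    def filter(self, record):
        artifact = getattr(record, "metadata", {}).get("artifact")
        return (
            artifact is not None
            and artifact["name"] == "inductor_provenance_tracking_kernel_stack_traces"
        )


class PayloadFormatter(logging.Formatter):
    # Emit only the record's JSON payload, not the usual log prefix.
    def format(self, record):
        return record.payload.strip()


# Mirrors the test's config patch: disable the FX graph cache so compilation
# actually runs, and turn provenance tracking up to level 2.
with torch._inductor.config.patch(
    {"fx_graph_cache": False, "trace.provenance_tracking_level": 2}
):
    buf = io.StringIO()
    handler = logging.StreamHandler(buf)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(PayloadFormatter())
    handler.addFilter(KernelStackTraceFilter())
    trace_log.addHandler(handler)
    try:
        compiled = torch.compile(lambda x: torch.relu(x * 3.14))
        compiled(torch.randn(8, 8))
    finally:
        trace_log.removeHandler(handler)

payload = buf.getvalue().strip()
if payload:
    # {kernel_name: [stack-trace strings]}
    print(json.loads(payload))
```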
test/inductor/test_provenance_tracing.py

```diff
@@ -1,5 +1,7 @@
 # Owner(s): ["module: inductor"]
 
+import contextlib
+import io
 import json
 import logging
 import re
@@ -28,6 +30,9 @@ except ImportError:
     from test_aot_inductor_utils import AOTIRunnerUtil
 
 
+trace_log = logging.getLogger("torch.__trace")
+
+
 class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -483,5 +488,102 @@ class TestProvenanceTracingNodeMeta(TestCase):
         self.assertEqual(mm_node.meta["stack_trace"], stack_trace)
 
 
+class ProvenanceArtifactFilter(logging.Filter):
+    def filter(self, record):
+        if "artifact" in record.metadata:
+            return (
+                record.metadata["artifact"]["name"]
+                == "inductor_provenance_tracking_kernel_stack_traces"
+            )
+        return False
+
+
+class StructuredTracePayloadFormatter(logging.Formatter):
+    def format(self, record):
+        return record.payload.strip()
+
+
+class TestProvenanceTracingStackTraces(TestCase):
+    @contextlib.contextmanager
+    def _setup_provenance_capture(self):
+        """Helper to turn on and capture the 'inductor_provenance_tracking_kernel_stack_traces' structured trace."""
+        payload_buffer = io.StringIO()
+        payload_handler = logging.StreamHandler(payload_buffer)
+        payload_handler.setLevel(logging.DEBUG)
+        payload_handler.setFormatter(StructuredTracePayloadFormatter())
+        payload_handler.addFilter(ProvenanceArtifactFilter())
+        trace_log.addHandler(payload_handler)
+        try:
+            yield payload_buffer
+        finally:
+            trace_log.removeHandler(payload_handler)
+
+    def extract_code_line(self, s):
+        # Extract the last non-empty line of the stack-trace string
+        return s.split("\n")[-2].strip()
+
+    @torch._inductor.config.patch(
+        {"fx_graph_cache": False, "trace.provenance_tracking_level": 2}
+    )
+    @requires_cuda_and_triton
+    def test_tlparse_kernel_stack_traces(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(10, 16)
+                self.relu = torch.nn.ReLU()
+                self.sigmoid = torch.nn.Sigmoid()
+
+            def forward(self, x, a, b, c):
+                x = self.fc1(x)
+                x = self.relu(x)
+                x = self.sigmoid(x)
+                d = a * 3.14
+                y = torch.addmm(c, d, b)
+                z = torch.nn.functional.gelu(y)
+                return x, z
+
+        device = "cuda"
+        model = Model().to(device)
+        x = torch.randn(8, 10).to(device)
+        a = torch.randn(10, 20).to(device)
+        b = torch.randn(20, 30).to(device)
+        c = torch.randn(10, 30).to(device)
+        example_inputs = (x, a, b, c)
+
+        expected = {
+            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0": [
+                "x = self.sigmoid(x)",
+                "x = self.fc1(x)",
+                "x = self.relu(x)",
+            ],
+            "triton_poi_fused_mul_1": [
+                "d = a * 3.14",
+            ],
+            "triton_poi_fused_addmm_gelu_2": [
+                "z = torch.nn.functional.gelu(y)",
+                "y = torch.addmm(c, d, b)",
+            ],
+            "extern_kernels.mm": [
+                "y = torch.addmm(c, d, b)",
+            ],
+        }
+
+        with self._setup_provenance_capture() as payload_buffer:
+            compiled = torch.compile(model)
+            compiled(*example_inputs)
+        payload_content = payload_buffer.getvalue().strip()
+        if payload_content:
+            data = json.loads(payload_content)
+            self.assertEqual(set(data.keys()), set(expected.keys()))
+            for key, expected_lines in expected.items():
+                actual_lines = [self.extract_code_line(s) for s in data[key]]
+                self.assertEqual(
+                    sorted(actual_lines),
+                    sorted(expected_lines),
+                    f"Mismatch for key: {key}",
+                )
+
+
 if __name__ == "__main__":
     run_tests()
```
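One subtlety worth calling out: `extract_code_line` takes `split("\n")[-2]`, which assumes every captured stack-trace string ends with the offending source line followed by a trailing newline. A tiny illustration with a made-up trace string:

```python
# Hypothetical stack-trace string of the shape the test expects:
s = 'File "model.py", line 42, in forward\n    x = self.relu(x)\n'
# split("\n") -> ['File "model.py", line 42, in forward', '    x = self.relu(x)', '']
assert s.split("\n")[-2].strip() == "x = self.relu(x)"
```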
torch/_inductor/compile_fx.py

```diff
@@ -1070,6 +1070,16 @@ def _compile_fx_inner(
                 torch._inductor.debug.dump_inductor_provenance_info()
             ),
         )
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "inductor_provenance_tracking_kernel_stack_traces",
+                "encoding": "json",
+            },
+            payload_fn=lambda: json.dumps(
+                torch._inductor.debug._inductor_kernel_stack_trace
+            ),
+        )
 
     # This message is for printing overview information of inductor mm counts, shapes, etc. after lowering
     if log.isEnabledFor(logging.INFO):
```
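Since the payload is simply `_inductor_kernel_stack_trace` run through `json.dumps`, and that global is annotated `dict[str, list[str]]` in the next file, a consumer can expect a JSON object of the following shape; the kernel names and trace text here are illustrative, not real output:

```python
# Illustrative payload shape only; keys and traces are made up.
payload = {
    "triton_poi_fused_mul_1": [
        'File "model.py", line 12, in forward\n    d = a * 3.14\n',
    ],
    "extern_kernels.mm": [
        'File "model.py", line 13, in forward\n    y = torch.addmm(c, d, b)\n',
    ],
}
```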
torch/_inductor/debug.py

```diff
@@ -319,6 +319,7 @@ _inductor_post_to_pre_grad_nodes: dict[str, dict[str, list[str]]] = {}
 _inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {}
 _pre_grad_graph_id: Optional[int] = None
 _inductor_pre_grad_node_stack_trace: dict[str, str] = {}
+_inductor_kernel_stack_trace: dict[str, list[str]] = {}
 
 
 @contextlib.contextmanager
@@ -328,6 +329,8 @@ def reset_provenance_globals() -> Iterator[None]:
     global _pre_grad_graph_id
     global _inductor_post_to_pre_grad_nodes
     global _inductor_triton_kernel_to_post_grad_node_info
+    global _inductor_pre_grad_node_stack_trace
+    global _inductor_kernel_stack_trace
 
     # Store original values
     original_pre_grad_graph_id = _pre_grad_graph_id
@@ -335,11 +338,17 @@ def reset_provenance_globals() -> Iterator[None]:
     original_triton_kernel_to_post_grad_node_info = (
         _inductor_triton_kernel_to_post_grad_node_info.copy()
     )
+    original_inductor_pre_grad_node_stack_trace = (
+        _inductor_pre_grad_node_stack_trace.copy()
+    )
+    original_inductor_kernel_stack_trace = _inductor_kernel_stack_trace.copy()
 
     # Reset to default values
     _pre_grad_graph_id = -1
     _inductor_post_to_pre_grad_nodes = {}
     _inductor_triton_kernel_to_post_grad_node_info = {}
+    _inductor_pre_grad_node_stack_trace = {}
+    _inductor_kernel_stack_trace = {}
 
     try:
         yield
@@ -350,6 +359,10 @@ def reset_provenance_globals() -> Iterator[None]:
         _inductor_triton_kernel_to_post_grad_node_info = (
             original_triton_kernel_to_post_grad_node_info
         )
+        _inductor_kernel_stack_trace = original_inductor_kernel_stack_trace
+        _inductor_pre_grad_node_stack_trace = (
+            original_inductor_pre_grad_node_stack_trace
+        )
 
 
 class DebugContext:
@@ -942,6 +955,7 @@ def set_kernel_post_grad_provenance_tracing(
         from .codegen.simd_kernel_features import DisableReduction, EnableReduction
 
         global _inductor_triton_kernel_to_post_grad_node_info
+        global _inductor_kernel_stack_trace
         if is_extern:
             assert isinstance(node_schedule, ExternKernelOut)
             curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
@@ -960,8 +974,12 @@ def set_kernel_post_grad_provenance_tracing(
                 for origin in node_schedule.origins
                 if origin.name not in curr_node_info
             )
+            _inductor_kernel_stack_trace[kernel_name] = list(
+                node_schedule.get_stack_traces()
+            )
         else:
             assert isinstance(node_schedule, list)
+            stack_traces: OrderedSet[str] = OrderedSet()
             for snode in node_schedule:
                 if snode not in (EnableReduction, DisableReduction):
                     if snode.node is not None:
@@ -970,11 +988,13 @@ def set_kernel_post_grad_provenance_tracing(
                                 kernel_name, []
                             )
                         )
+                        stack_traces.update(snode.node.get_stack_traces())
                         curr_node_info.extend(
                             origin.name
                             for origin in snode.node.origins
                             if origin.name not in curr_node_info
                         )
+            _inductor_kernel_stack_trace[kernel_name] = list(stack_traces)
     except Exception as e:
         # Since this is just debugging, it should never interfere with regular
         # program execution, so we use this try-except to guard against any error
```
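Design note: in the fused-kernel path above, stack traces are accumulated across every scheduler node fused into the kernel via an `OrderedSet`, so each source line is reported once, in first-seen order. A standalone sketch of that dedup pattern using a plain dict (the node structure here is hypothetical, not Inductor's scheduler API):

```python
# Deduplicate stack traces across fused nodes while keeping first-seen order,
# mirroring the OrderedSet accumulation in set_kernel_post_grad_provenance_tracing.
def collect_stack_traces(nodes):
    seen = dict.fromkeys(
        trace for node in nodes for trace in node["stack_traces"]
    )
    return list(seen)


nodes = [
    {"stack_traces": ["x = self.fc1(x)", "x = self.relu(x)"]},
    {"stack_traces": ["x = self.relu(x)", "x = self.sigmoid(x)"]},  # dup dropped
]
assert collect_stack_traces(nodes) == [
    "x = self.fc1(x)",
    "x = self.relu(x)",
    "x = self.sigmoid(x)",
]
```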