pytorch/test/inductor/test_provenance_tracing.py
Shangdi Yu d2eff5d454 Add python stack trace to AOTI generated code (#160539)
Summary:
We add a thread_local KernelContext object so Strobelight (and other potential profilers) can read the stack trace information of the running kernel.

This adds extra overhead, so it is guarded behind the `cpp.enable_kernel_profile` flag.
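
For intuition, the guard can be read as a small RAII wrapper that publishes the current kernel's name and Python stack trace through the thread-local pointer while the kernel runs, and restores the previous value on scope exit. The sketch below is purely illustrative; the field names and layout are assumptions, not the contents of the real `kernel_context_tls.h`:

```cpp
// Illustrative sketch only -- not the actual kernel_context_tls.h header.
#include <string>

namespace torch::aot_inductor {

// What the running kernel exposes to a sampling profiler (assumed layout).
struct KernelContext {
  std::string kernel_name;   // e.g. "aoti_torch_cpu_addmm_out"
  std::string python_stack;  // the R"(...)" stack-trace literal emitted by Inductor
};

// One slot per thread; the generated model code defines this symbol,
// it is repeated here only to keep the sketch self-contained.
thread_local KernelContext* tls_kernel_context = nullptr;

// RAII guard: publish on construction, restore the previous context on destruction.
class KernelContextGuard {
 public:
  KernelContextGuard(std::string name, std::string stack)
      : ctx_{std::move(name), std::move(stack)}, prev_(tls_kernel_context) {
    tls_kernel_context = &ctx_;
  }
  ~KernelContextGuard() {
    tls_kernel_context = prev_;
  }
  KernelContextGuard(const KernelContextGuard&) = delete;
  KernelContextGuard& operator=(const KernelContextGuard&) = delete;

 private:
  KernelContext ctx_;
  KernelContext* prev_;
};

}  // namespace torch::aot_inductor
```

With a layout like this, a sampling profiler only has to dereference `tls_kernel_context` for the sampled thread to attribute the sample to both the kernel name and the originating Python source lines.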

Example output code:

```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {
    __check_inputs_outputs(input_handles, output_handles);
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
    auto arg2_1 = std::move(inputs[0]);
    auto arg3_1 = std::move(inputs[1]);
    auto arg4_1 = std::move(inputs[2]);
    auto arg5_1 = std::move(inputs[3]);
    [[maybe_unused]] auto& fc1_weight = constants_->at(0);
    [[maybe_unused]] auto& fc1_bias = constants_->at(1);
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {8L, 16L};
    static constexpr int64_t int_array_1[] = {16L, 1L};
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    // Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:1
    static constexpr int64_t int_array_2[] = {10L, 16L};
    static constexpr int64_t int_array_3[] = {1L, 10L};
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
    x = self.fc1(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
    }
    arg2_1.reset();
    auto buf1 = std::move(buf0);  // reuse
    static constexpr int64_t int_array_4[] = {10L, 20L};
    static constexpr int64_t int_array_5[] = {20L, 1L};
    AtenTensorHandle buf2_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
    RAIIAtenTensorHandle buf2(buf2_handle);
    // [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
    {
    KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
    x = self.sigmoid(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
    return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
    x = self.relu(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
    return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
    d = a * 3.14
)");
    cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
    }
    arg3_1.reset();
    static constexpr int64_t int_array_6[] = {10L, 30L};
    static constexpr int64_t int_array_7[] = {30L, 1L};
    AtenTensorHandle buf3_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
    RAIIAtenTensorHandle buf3(buf3_handle);
    // Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:3
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
    y = torch.addmm(c, d, b)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
    }
    arg4_1.reset();
    arg5_1.reset();
    buf2.reset();
    auto buf4 = std::move(buf3);  // reuse
    // [Provenance debug handles] cpp_fused_gelu_1:4
    {
    KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
    z = torch.nn.functional.gelu(y)
)");
    cpp_fused_gelu_1((float*)(buf4.data_ptr()));
    }
    output_handles[0] = buf1.release();
    output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r  stack_traces
```

Rollback Plan:

Differential Revision: D78436007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
2025-10-29 22:47:52 +00:00

# Owner(s): ["module: inductor"]
import contextlib
import io
import json
import logging
import os
import re
import shutil
import sys
import tempfile
import unittest
import zipfile
from pathlib import Path

import torch
from torch._C import FileCheck
from torch._dynamo.utils import detect_fake_mode
from torch._inductor import config
from torch._inductor.debug import (
    create_kernel_information_json,
    create_mapping_pre_post_grad_nodes,
    create_node_mapping_kernel_to_post_grad,
    reset_inductor_kernel_provenance_debug_handle,
)
from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.triton_utils import requires_cuda_and_triton


try:
    from .test_aot_inductor_utils import AOTIRunnerUtil
    from .test_torchinductor import copy_tests
except ImportError:
    from test_aot_inductor_utils import AOTIRunnerUtil
    from test_torchinductor import (
        copy_tests,  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
    )


trace_log = logging.getLogger("torch.__trace")


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b, c):
        x = a * 3.14
        y = torch.addmm(c, x, b)
        z = torch.nn.functional.gelu(y)
        return z


class Model2(torch.nn.Module):
    # this test model is used for combo kernel provenance tracing info
    def __init__(self):
        super().__init__()

    def forward(self, a, b, c):
        a1 = torch.nn.functional.relu(a)
        b1 = torch.nn.functional.sigmoid(b)
        c1 = torch.nn.functional.tanh(c)
        return a1, b1, c1


class Model3(torch.nn.Module):
    def __init__(self, n, k):
        super().__init__()
        self.weight = torch.randn(n, k, device="cuda")
        self.bias = torch.randn(n, device="cuda")

    def forward(self, a):
        return torch.nn.functional.linear(a, self.weight, self.bias)


class Model4(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, a, b, c):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.sigmoid(x)
        d = a * 3.14
        y = torch.addmm(c, d, b)
        z = torch.nn.functional.gelu(y)
        return x, z


@config.patch("trace.enabled", True)
@config.patch("trace.provenance_tracking_level", 1)
class TestProvenanceTracingArtifact(TestCase):
    """
    This test checks that the provenance tracing artifact generated from "post_grad"
    to the corresponding "inductor triton kernel node" is as expected.
    """

    def _check_provenance_tracing_kernel_to_post_grad(self, filepath, expected_data):
        self.assertTrue(filepath.is_dir())
        filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
        with open(filename) as f:
            actual_data = json.load(f)
        actual_data = actual_data["cppCodeToPost"]
        # check that the generated provenance tracing artifact is expected
        self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))

    def _check_provenance_tracking_node_mappings(self, filepath, expected_mapping):
        self.assertTrue(filepath.is_dir())
        filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
        with open(filename) as f:
            actual_data = json.load(f)
        # check that the generated provenance tracing node mapping is expected
        self.assertEqual(sorted(actual_data.items()), sorted(expected_mapping))

    def _test_triton_kernel_to_post_grad_tracing(self, device):
        a = torch.randn(10, 20, device=device)
        b = torch.randn(20, 30, device=device)
        c = torch.randn(10, 30, device=device)
        example_inputs = (a, b, c)
        model = Model().to(device)
        filepath = None
        for backend in ["aot_inductor", "inductor"]:
            reset_inductor_kernel_provenance_debug_handle()
            try:
                with config.patch(
                    {
                        "trace.debug_dir": tempfile.mkdtemp(),
                        "force_disable_caches": True,
                    }
                ):
                    with self.assertLogs(
                        logging.getLogger("torch._inductor.debug"),
                        level=logging.WARNING,
                    ) as cm:
                        if backend == "aot_inductor":
                            AOTIRunnerUtil.run(model, example_inputs)
                        else:
                            ep = torch.export._trace._export(model, example_inputs)
                            compiled = torch.compile(ep.module(), backend=backend)
                            compiled(*example_inputs)
                    self.assertEqual(len(cm.output), 1)
                    m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
                    self.assertTrue(m)
                    filepath = Path(m.group(1))
                    if device == "cuda":
                        expected_mapping = [
                            (
                                "cppCodeToPost",
                                {
                                    "triton_poi_fused_mul_0:1": ["mul"],
                                    "triton_poi_fused_addmm_gelu_1:3": [
                                        "mul_3",
                                        "mul_1",
                                        "add_tensor",
                                        "add",
                                        "erf",
                                        "mul_2",
                                    ],
                                },
                            ),
                            (
                                "postToCppCode",
                                {
                                    "mul": ["triton_poi_fused_mul_0:1"],
                                    "mul_3": ["triton_poi_fused_addmm_gelu_1:3"],
                                    "mul_1": ["triton_poi_fused_addmm_gelu_1:3"],
                                    "add_tensor": ["triton_poi_fused_addmm_gelu_1:3"],
                                    "add": ["triton_poi_fused_addmm_gelu_1:3"],
                                    "erf": ["triton_poi_fused_addmm_gelu_1:3"],
                                    "mul_2": ["triton_poi_fused_addmm_gelu_1:3"],
                                },
                            ),
                            (
                                "postToPre",
                                {
                                    "mul": ["mul"],
                                    "mm_default": ["addmm"],
                                    "add_tensor": ["addmm"],
                                    "mul_1": ["gelu"],
                                    "mul_2": ["gelu"],
                                    "erf": ["gelu"],
                                    "add": ["gelu"],
                                    "mul_3": ["gelu"],
                                },
                            ),
                            (
                                "preToPost",
                                {
                                    "mul": ["mul"],
                                    "addmm": ["mm_default", "add_tensor"],
                                    "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
                                },
                            ),
                        ]
                        if backend == "aot_inductor":
                            expected_mapping[0][1]["aoti_torch_cuda_mm_out:2"] = [
                                "mm_default"
                            ]
                            expected_mapping[1][1]["mm_default"] = [
                                "aoti_torch_cuda_mm_out:2"
                            ]
                        else:
                            expected_mapping[0][1]["extern_kernels.mm:2"] = [
                                "mm_default"
                            ]
                            expected_mapping[1][1]["mm_default"] = [
                                "extern_kernels.mm:2"
                            ]
                        self._check_provenance_tracking_node_mappings(
                            filepath, expected_mapping
                        )
                    else:
                        assert device == "cpu"
                        # check the inductor kernel to post grad nodes mapping is expected for cpu
                        if backend == "aot_inductor":
                            expected_data = {
                                "cpp_fused_mul_0:1": ["mul"],
                                "aoti_torch_cpu_addmm_out:2": ["addmm"],
                                "cpp_fused_gelu_1:3": [
                                    "mul_3",
                                    "mul_1",
                                    "add",
                                    "erf",
                                    "mul_2",
                                ],
                            }
                        else:
                            # backend == "inductor"
                            expected_data = {
                                "cpp_fused_mul_0:1": ["mul"],
                                "cpp_fused_gelu_1:3": [
                                    "mul_3",
                                    "mul_1",
                                    "add",
                                    "erf",
                                    "mul_2",
                                ],
                                "extern_kernels.addmm:2": ["addmm"],
                            }
                        self._check_provenance_tracing_kernel_to_post_grad(
                            filepath, expected_data
                        )
            finally:
                if filepath:
                    shutil.rmtree(filepath)

    @requires_cuda_and_triton
    def test_triton_kernel_to_post_grad_tracing_cuda(self):
        self._test_triton_kernel_to_post_grad_tracing(device="cuda")

    def test_triton_kernel_to_post_grad_tracing_cpu(self):
        self._test_triton_kernel_to_post_grad_tracing(device="cpu")

    @requires_cuda_and_triton
    def test_triton_kernel_to_post_grad_tracing_extern_kernel(self):
        M = 8
        N = 6
        K = 16
        model = Model3(N, K)
        batch = 2
        a = torch.randn(batch, M, K, device="cuda")
        example_inputs = (a,)
        filepath = None
        for backend in ["aot_inductor", "inductor"]:
            reset_inductor_kernel_provenance_debug_handle()
            try:
                with config.patch(
                    {
                        "trace.debug_dir": tempfile.mkdtemp(),
                        "force_disable_caches": True,
                    }
                ):
                    with self.assertLogs(
                        logging.getLogger("torch._inductor.debug"),
                        level=logging.WARNING,
                    ) as cm:
                        if backend == "aot_inductor":
                            AOTIRunnerUtil.run(model, example_inputs)
                        else:
                            ep = torch.export._trace._export(model, example_inputs)
                            compiled = torch.compile(ep.module(), backend=backend)
                            compiled(*example_inputs)
                    self.assertEqual(len(cm.output), 1)
                    m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
                    self.assertTrue(m)
                    filepath = Path(m.group(1))
                    if backend == "inductor":
                        expected_data = {
                            "extern_kernels.addmm:1": ["addmm"],
                        }
                    else:
                        # backend = aot_inductor
                        expected_data = {
                            "aoti_torch_cuda_addmm_out:2": ["addmm"],
                            "triton_poi_fused_0:1": ["_tensor_constant1"],
                        }
                    self._check_provenance_tracing_kernel_to_post_grad(
                        filepath, expected_data
                    )
            finally:
                if filepath:
                    shutil.rmtree(filepath)

    @requires_cuda_and_triton
    def _test_pt_tracing_combo_kernel(self, backend):
        """This test checks the provenance tracing artifact generated from a triton combo kernel to post grad nodes."""
        a = torch.randn(10, 10, device="cuda")
        b = torch.randn(20, 20, device="cuda")
        c = torch.randn(10, 10, device="cuda")
        example_inputs = (a, b, c)
        model = Model2()
        reset_inductor_kernel_provenance_debug_handle()
        with config.patch(
            {
                "trace.debug_dir": tempfile.mkdtemp(),
                "force_disable_caches": True,
                "combo_kernels": True,
                "benchmark_combo_kernel": False,
            }
        ):
            with self.assertLogs(
                logging.getLogger("torch._inductor.debug"),
                level=logging.WARNING,
            ) as cm:
                if backend == "aot_inductor":
                    AOTIRunnerUtil.run(model, example_inputs)
                else:
                    ep = torch.export._trace._export(model, example_inputs)
                    compiled = torch.compile(ep.module(), backend=backend)
                    compiled(*example_inputs)
            self.assertEqual(len(cm.output), 1)
            m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
            self.assertTrue(m)
            filepath = Path(m.group(1)).resolve()
            expected_data = {"triton_poi_fused_0:1": ["relu", "sigmoid", "tanh"]}
            self._check_provenance_tracing_kernel_to_post_grad(filepath, expected_data)

    @requires_cuda_and_triton
    def test_triton_kernel_to_post_grad_tracing_combo_kernel(self):
        self._test_pt_tracing_combo_kernel(backend="inductor")
        self._test_pt_tracing_combo_kernel(backend="aot_inductor")


class TestProvenanceTracingNodeMapping(TestCase):
    def test_create_node_mapping(self):
        pre_grad_graph_id = 140156815043952
        post_to_pre_grad_nodes_json = {
            "add_tensor": [
                {
                    "from_node": [
                        {
                            "from_node": [
                                {
                                    "from_node": [],
                                    "graph_id": 140156815043952,
                                    "name": "linear",
                                }
                            ],
                            "graph_id": 140152856025632,
                            "name": "addmm",
                        }
                    ],
                    "graph_id": 140151961816272,
                    "name": "add",
                },
            ],
            "mm_default": [
                {
                    "from_node": [],
                    "graph_id": -1,
                    "name": "",
                },
                {
                    "from_node": [
                        {
                            "from_node": [
                                {
                                    "from_node": [],
                                    "graph_id": 140156815043952,
                                    "name": "linear",
                                }
                            ],
                            "graph_id": 140152856025632,
                            "name": "addmm",
                        }
                    ],
                    "graph_id": 140151961816272,
                    "name": "mm",
                },
            ],
            "permute": [
                {
                    "from_node": [],
                    "graph_id": 140156815043952,
                    "name": "linear",
                }
            ],
            "relu": [
                {
                    "from_node": [],
                    "graph_id": 140156815043952,
                    "name": "relu",
                }
            ],
        }
        triton_kernel_to_post_grad_json = {
            "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"]
        }

        result = create_mapping_pre_post_grad_nodes(
            pre_grad_graph_id,
            post_to_pre_grad_nodes_json,
        )
        result = {
            **result,
            **create_node_mapping_kernel_to_post_grad(
                triton_kernel_to_post_grad_json,
            ),
        }

        self.assertEqual(
            result,
            {
                "cppCodeToPost": {
                    "triton_poi_fused_addmm_relu_sigmoid_0": [
                        "relu",
                        "add_tensor",
                    ]
                },
                "postToCppCode": {
                    "add_tensor": ["triton_poi_fused_addmm_relu_sigmoid_0"],
                    "relu": ["triton_poi_fused_addmm_relu_sigmoid_0"],
                },
                "postToPre": {
                    "add_tensor": ["linear"],
                    "mm_default": ["linear"],
                    "permute": ["linear"],
                    "relu": ["relu"],
                },
                "preToPost": {
                    "linear": ["add_tensor", "mm_default", "permute"],
                    "relu": ["relu"],
                },
            },
        )


class TestProvenanceTracingNodeMeta(TestCase):
    def get_node_with_target(self, gm, target):
        """
        Return the first node in gm with the given target.
        """
        return next(iter([node for node in gm.graph.nodes if node.target == target]))

    @requires_cuda_and_triton  # test only works for cuda pattern matcher
    def test_pattern_matcher_transfer_meta(self):
        """
        Test that the stack trace is transferred when a node is decomposed in post_grad_passes.
        """

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.sigmoid(x)
                return x * 3

        x = torch.randn(8, 10).to("cuda")
        example_inputs = (x,)
        model = Model().to("cuda")
        # mimic the before_post_grad graph
        ep = torch.export.export(model, example_inputs).run_decompositions()
        gm = ep.module()
        # Set fake mode for V
        fake_inputs = [
            node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"
        ]
        fake_mode = detect_fake_mode(fake_inputs)
        V.set_fake_mode(fake_mode)

        addmm_node = self.get_node_with_target(gm, torch.ops.aten.addmm.default)
        stack_trace = addmm_node.meta["stack_trace"]

        post_grad_passes(gm, True)  # for this test is_inference doesn't matter

        mm_node = self.get_node_with_target(gm, torch.ops.aten.mm.default)
        add_node = self.get_node_with_target(gm, torch.ops.aten.add.Tensor)

        self.assertEqual(add_node.meta["stack_trace"], stack_trace)
        self.assertEqual(mm_node.meta["stack_trace"], stack_trace)


class ProvenanceArtifactFilter(logging.Filter):
    def filter(self, record):
        if "artifact" in record.metadata:
            return (
                record.metadata["artifact"]["name"]
                == "inductor_provenance_tracking_kernel_stack_traces"
            )
        return False


class StructuredTracePayloadFormatter(logging.Formatter):
    def format(self, record):
        return record.payload.strip()


class TestProvenanceTracingStackTraces(TestCase):
    @contextlib.contextmanager
    def _setup_provenance_capture(self):
        """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace."""
        payload_buffer = io.StringIO()
        payload_handler = logging.StreamHandler(payload_buffer)
        payload_handler.setLevel(logging.DEBUG)
        payload_handler.setFormatter(StructuredTracePayloadFormatter())
        payload_handler.addFilter(ProvenanceArtifactFilter())
        trace_log.addHandler(payload_handler)
        try:
            yield payload_buffer
        finally:
            trace_log.removeHandler(payload_handler)

    def extract_code_line(self, s, i=-2):
        # Extract ith line
        return s.split("\n")[i].strip()

    @torch._inductor.config.patch({"trace.provenance_tracking_level": 2})
    @requires_cuda_and_triton
    def test_tlparse_kernel_stack_traces(self):
        device = "cuda"
        model = Model4().to(device)
        x = torch.randn(8, 10).to(device)
        a = torch.randn(10, 20).to(device)
        b = torch.randn(20, 30).to(device)
        c = torch.randn(10, 30).to(device)
        example_inputs = (x, a, b, c)

        expected = {
            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0:2": [
                "x = self.sigmoid(x)",
                "x = self.fc1(x)",
                "x = self.relu(x)",
            ],
            "triton_poi_fused_mul_1:3": [
                "d = a * 3.14",
            ],
            "triton_poi_fused_addmm_gelu_2:5": [
                "z = torch.nn.functional.gelu(y)",
                "y = torch.addmm(c, d, b)",
            ],
            "extern_kernels.mm:1": [
                "x = self.fc1(x)",
            ],
            "extern_kernels.mm:4": [
                "y = torch.addmm(c, d, b)",
            ],
        }

        compiled = torch.compile(model)
        # should produce the same provenance if there's cache hit
        for _ in range(2):
            # reset cache
            torch._dynamo.reset()
            reset_inductor_kernel_provenance_debug_handle()
            with self._setup_provenance_capture() as payload_buffer:
                compiled = torch.compile(model)
                compiled(*example_inputs)
                payload_content = payload_buffer.getvalue().strip()
                data = json.loads(payload_content)
                self.assertEqual(set(data.keys()), set(expected.keys()))
                for key, expected_lines in expected.items():
                    actual_lines = [self.extract_code_line(s) for s in data[key]]
                    self.assertEqual(
                        sorted(actual_lines),
                        sorted(expected_lines),
                        f"Mismatch for key: {key}",
                    )

    @torch._inductor.config.patch(
        {"trace.provenance_tracking_level": 2, "max_autotune_gemm_backends": "ATEN"}
    )
    @requires_cuda_and_triton
    def test_deferred_triton_kernels(self):
        def foo(m, inp):
            a = m(inp)
            return a

        foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
        m = torch.nn.Linear(512, 512, bias=True).half().cuda()
        inp = torch.rand([1, 512]).half().cuda()

        with self._setup_provenance_capture() as payload_buffer:
            with torch.no_grad():
                _, out_code = run_and_get_code(foo_c, m, inp)
            payload_content = payload_buffer.getvalue().strip()
            data = json.loads(payload_content)
            self.assertTrue("a = m(inp)" in str(data))

        # Check that debug handle is in the output code
        FileCheck().check("Topologically Sorted Source Nodes: [a]").check(
            "[Provenance debug handles]"
        ).run(out_code[0])

    def _check_kernel_information_json(self, kernel_info, expected_kernels):
        """Validate kernel information JSON structure and content."""
        self.assertIsInstance(kernel_info, dict)
        for expected in expected_kernels:
            self.assertIn(
                expected,
                kernel_info,
                f"Expected kernel {expected} not found in {list(kernel_info)}",
            )
        for data in kernel_info.values():
            self.assertIsInstance(data, dict)
            for field in ["stack_traces", "post_grad_nodes", "pre_grad_nodes"]:
                self.assertIn(field, data)
                self.assertIsInstance(data[field], list)
                for item in data[field]:
                    self.assertIsInstance(item, str)

    @requires_cuda_and_triton
    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
    def test_kernel_information_generation(self):
        """Test basic kernel information generation in AOTI packages."""
        model = Model4().to("cuda")
        x = torch.randn(8, 10, device="cuda")
        a = torch.randn(10, 20, device="cuda")
        b = torch.randn(20, 30, device="cuda")
        c = torch.randn(10, 30, device="cuda")
        inputs = (x, a, b, c)

        with tempfile.TemporaryDirectory() as temp_dir:
            ep = torch.export.export(model, inputs, strict=False)
            pt2_file = os.path.join(temp_dir, "model.pt2")
            reset_inductor_kernel_provenance_debug_handle()
            torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)

            # Extract and check kernel_information.json exists in the package
            with zipfile.ZipFile(pt2_file, "r") as zip_ref:
                zip_ref.extractall(temp_dir)
            json_path = os.path.join(
                temp_dir,
                "model",
                "data",
                "aotinductor",
                "model",
                "kernel_information.json",
            )
            self.assertTrue(
                os.path.exists(json_path),
                f"kernel_information.json not found in extracted package at {json_path}",
            )
            with open(json_path) as f:
                kernel_info = json.load(f)

            expected = {
                "triton_poi_fused_addmm_relu_sigmoid_0:2": {
                    "stack_traces": [
                        "x = self.sigmoid(x)",
                        "x = self.fc1(x)",
                        "x = self.relu(x)",
                    ],
                    "post_grad_nodes": ["sigmoid", "relu", "add_tensor_1"],
                    "pre_grad_nodes": ["sigmoid", "relu", "linear"],
                },
                "triton_poi_fused_mul_1:3": {
                    "stack_traces": [
                        "d = a * 3.14",
                    ],
                    "post_grad_nodes": ["mul"],
                    "pre_grad_nodes": ["mul"],
                },
                "triton_poi_fused_addmm_gelu_2:5": {
                    "stack_traces": [
                        "z = torch.nn.functional.gelu(y)",
                        "y = torch.addmm(c, d, b)",
                    ],
                    "post_grad_nodes": [
                        "mul_3",
                        "mul_1",
                        "add_tensor",
                        "add",
                        "erf",
                        "mul_2",
                    ],
                    "pre_grad_nodes": ["gelu", "addmm"],
                },
                "aoti_torch_cuda_mm_out:1": {
                    "stack_traces": [
                        "x = self.fc1(x)",
                    ],
                    "post_grad_nodes": ["mm_default_1"],
                    "pre_grad_nodes": ["linear"],
                },
                "aoti_torch_cuda_mm_out:4": {
                    "stack_traces": [
                        "y = torch.addmm(c, d, b)",
                    ],
                    "post_grad_nodes": ["mm_default"],
                    "pre_grad_nodes": ["addmm"],
                },
            }

            self._check_kernel_information_json(kernel_info, expected.keys())
            self.assertEqual(set(kernel_info.keys()), set(expected.keys()))
            for key, data in expected.items():
                all_lines = ",".join(kernel_info[key]["stack_traces"])
                for s in data["stack_traces"]:
                    self.assertTrue(s in all_lines)
                self.assertEqual(
                    sorted(kernel_info[key]["pre_grad_nodes"]),
                    sorted(data["pre_grad_nodes"]),
                    f"Mismatch for key: {key}",
                )
                self.assertEqual(
                    sorted(kernel_info[key]["post_grad_nodes"]),
                    sorted(data["post_grad_nodes"]),
                    f"Mismatch for key: {key}",
                )

    @torch._inductor.config.patch("trace.provenance_tracking_level", 0)
    def test_no_kernel_information_without_provenance_tracking(self):
        """Test that kernel_information.json is not generated without provenance tracking."""

        class SimpleModel(torch.nn.Module):
            def forward(self, x):
                return x * 2.0

        model = SimpleModel()
        x = torch.randn(4, 8)

        # Compile with AOTI but without provenance tracking
        with tempfile.TemporaryDirectory() as temp_dir:
            ep = torch.export.export(model, (x,), strict=False)
            pt2_file = os.path.join(temp_dir, "model.pt2")
            torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)

            # Extract and check kernel_information.json was NOT created in the package
            extract_dir = os.path.join(temp_dir, "extracted")
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(pt2_file, "r") as zip_ref:
                zip_ref.extractall(extract_dir)
            expected_json_path = os.path.join(extract_dir, "kernel_information.json")
            self.assertFalse(
                os.path.exists(expected_json_path),
                "kernel_information.json should not exist in package when provenance tracking is disabled",
            )

    def test_create_kernel_information_json_function(self):
        """Test the create_kernel_information_json function directly."""
        # Test with empty state
        result = create_kernel_information_json()
        self.assertIsInstance(result, dict)
        self.assertEqual(len(result), 0)  # Should be empty with no provenance data

    @unittest.skipIf(
        IS_MACOS,
        "MacOS generates different debug handles",
    )
    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
    def test_cpu_extern_kernel(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)

            def forward(self, x):
                return self.conv(x)

        model = Foo()
        x = torch.randn(20, 16, 50, 100)
        with self._setup_provenance_capture() as payload_buffer:
            reset_inductor_kernel_provenance_debug_handle()
            ep = torch.export.export(model, (x,))
            torch._inductor.aoti_compile_and_package(ep)
            payload_content = payload_buffer.getvalue().strip()
            data = json.loads(payload_content)
            keys = [k.split(":")[0] for k in data]
            self.assertTrue("aoti_torch_cpu_convolution" in keys)


class ProvenanceTracingKernelContextTemplate:
    def test_jit_inductor_with_flag(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x, a, b, c):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.sigmoid(x)
                d = a * 3.14
                y = torch.addmm(c, d, b)
                z = torch.nn.functional.gelu(y)
                return x, z

        model = Model().to(self.device)
        x = torch.randn(8, 10).to(self.device)
        a = torch.randn(10, 20).to(self.device)
        b = torch.randn(20, 30).to(self.device)
        c = torch.randn(10, 30).to(self.device)
        example_inputs = (x, a, b, c)
        with config.patch(
            {
                "cpp.enable_kernel_profile": True,
            }
        ):
            torch.compile(model)(*example_inputs)

    @unittest.skipIf(sys.platform == "darwin", "Different kernel names on MacOS")
    def test_aoti_python_stack_traces(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x, a, b, c):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.sigmoid(x)
                d = a * 3.14
                y = torch.addmm(c, d, b)
                z = torch.nn.functional.gelu(y)
                return x, z

        x = torch.randn(8, 10).to(self.device)
        a = torch.randn(10, 20).to(self.device)
        b = torch.randn(20, 30).to(self.device)
        c = torch.randn(10, 30).to(self.device)
        example_inputs = (x, a, b, c)
        model = Model().to(self.device)
        ep = torch.export.export(model, example_inputs)

        _, code = run_and_get_cpp_code(torch._inductor.aoti_compile_and_package, ep)
        self.assertTrue("KernelContextGuard" not in code)

        with config.patch(
            {
                "trace.provenance_tracking_level": 1,
                "cpp.enable_kernel_profile": True,
            }
        ):
            package_path, code = run_and_get_cpp_code(
                torch._inductor.aoti_compile_and_package, ep
            )

        FileCheck().check(
            "#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
        ).check("thread_local KernelContext* tls_kernel_context = nullptr;").run(code)
        if self.device == "cuda":
            FileCheck().check(
                """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"("""
            ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check(
                """KernelContextGuard _ctx("triton_poi_fused_addmm_relu_sigmoid_0", R"("""
            ).check("call_triton_poi_fused_addmm_relu_sigmoid_0(").check(
                """KernelContextGuard _ctx("triton_poi_fused_mul_1", R"("""
            ).check("call_triton_poi_fused_mul_1(").check(
                """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"("""
            ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check(
                """ KernelContextGuard _ctx("triton_poi_fused_addmm_gelu_2", R"("""
            ).check("call_triton_poi_fused_addmm_gelu_2(").run(code)
        else:
            FileCheck().check(
                """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"("""
            ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check(
                """KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"("""
            ).check("cpp_fused_mul_relu_sigmoid_0(").check(
                """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"("""
            ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check(
                """ KernelContextGuard _ctx("cpp_fused_gelu_1", R"("""
            ).check("cpp_fused_gelu_1(").run(code)

        compiled_model = torch._inductor.aoti_load_package(package_path)
        result = compiled_model(*example_inputs)
        self.assertEqual(result, model(*example_inputs))


class TestProvenanceTracingKernelContextCpu(TestCase):
    device = "cpu"


copy_tests(
    ProvenanceTracingKernelContextTemplate,
    TestProvenanceTracingKernelContextCpu,
    "cpu",
)


@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
@unittest.skipIf(not torch.cuda.is_available(), "No CUDA")
class TestProvenanceTracingKernelContextGpu(TestCase):
    device = "cuda"


copy_tests(
    ProvenanceTracingKernelContextTemplate,
    TestProvenanceTracingKernelContextGpu,
    "cuda",
)


if __name__ == "__main__":
    run_tests()