Add python stack trace to AOTI generated code (#160539)

Summary:
We add a `thread_local` KernelContext object so that Strobelight (and other profilers) can read the Python stack trace of the kernel currently executing on a thread.

This adds some runtime overhead, so the feature is guarded behind the `cpp.enable_kernel_profile` flag.
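
For illustration, a profiler that samples a thread could read this context roughly as follows. This is a minimal sketch based on the `kernel_context_tls.h` header added in this PR; `dump_current_kernel` is a hypothetical helper, not part of the change:

```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>

#include <iostream>

// Hypothetical profiler-side hook: if an AOTI kernel is currently running on
// this thread, print its name and the captured Python stack trace.
void dump_current_kernel() {
  using torch::aot_inductor::current_kernel_context;
  if (auto* ctx = current_kernel_context()) {
    std::cout << "kernel: " << ctx->kernel_name << "\n"
              << ctx->python_stack << std::endl;
  }
}
```

The `KernelContextGuard` objects in the generated code below set and clear this thread-local pointer around each kernel call.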

Example output code:

```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {
    __check_inputs_outputs(input_handles, output_handles);
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
    auto arg2_1 = std::move(inputs[0]);
    auto arg3_1 = std::move(inputs[1]);
    auto arg4_1 = std::move(inputs[2]);
    auto arg5_1 = std::move(inputs[3]);
    [[maybe_unused]] auto& fc1_weight = constants_->at(0);
    [[maybe_unused]] auto& fc1_bias = constants_->at(1);
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {8L, 16L};
    static constexpr int64_t int_array_1[] = {16L, 1L};
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    // Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:1
    static constexpr int64_t int_array_2[] = {10L, 16L};
    static constexpr int64_t int_array_3[] = {1L, 10L};
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
    x = self.fc1(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
    }
    arg2_1.reset();
    auto buf1 = std::move(buf0);  // reuse
    static constexpr int64_t int_array_4[] = {10L, 20L};
    static constexpr int64_t int_array_5[] = {20L, 1L};
    AtenTensorHandle buf2_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
    RAIIAtenTensorHandle buf2(buf2_handle);
    // [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
    {
    KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
    x = self.sigmoid(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
    return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
    x = self.relu(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
    return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
    d = a * 3.14
)");
    cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
    }
    arg3_1.reset();
    static constexpr int64_t int_array_6[] = {10L, 30L};
    static constexpr int64_t int_array_7[] = {30L, 1L};
    AtenTensorHandle buf3_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
    RAIIAtenTensorHandle buf3(buf3_handle);
    // Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:3
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
    y = torch.addmm(c, d, b)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
    }
    arg4_1.reset();
    arg5_1.reset();
    buf2.reset();
    auto buf4 = std::move(buf3);  // reuse
    // [Provenance debug handles] cpp_fused_gelu_1:4
    {
    KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
    z = torch.nn.functional.gelu(y)
)");
    cpp_fused_gelu_1((float*)(buf4.data_ptr()));
    }
    output_handles[0] = buf1.release();
    output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r  stack_traces
```

Rollback Plan:

Differential Revision: D78436007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
Author: Shangdi Yu
Date: 2025-10-29 22:47:49 +00:00 (committed by PyTorch MergeBot)
parent 972030fe2e
commit d2eff5d454
7 changed files with 344 additions and 25 deletions

View File

@@ -7,6 +7,7 @@ import logging
import os
import re
import shutil
import sys
import tempfile
import unittest
import zipfile
@@ -24,7 +25,7 @@ from torch._inductor.debug import (
)
from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.triton_utils import requires_cuda_and_triton
@@ -32,8 +33,12 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
try:
    from .test_aot_inductor_utils import AOTIRunnerUtil
    from .test_torchinductor import copy_tests
except ImportError:
    from test_aot_inductor_utils import AOTIRunnerUtil
    from test_torchinductor import (
        copy_tests, # @manual=fbcode//caffe2/test/inductor:test_inductor-library
    )
trace_log = logging.getLogger("torch.__trace")
@@ -806,5 +811,135 @@ class TestProvenanceTracingStackTraces(TestCase):
        self.assertTrue("aoti_torch_cpu_convolution" in keys)
class ProvenanceTracingKernelContextTemplate:
    def test_jit_inductor_with_flag(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.sigmoid = torch.nn.Sigmoid()
            def forward(self, x, a, b, c):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.sigmoid(x)
                d = a * 3.14
                y = torch.addmm(c, d, b)
                z = torch.nn.functional.gelu(y)
                return x, z
        model = Model().to(self.device)
        x = torch.randn(8, 10).to(self.device)
        a = torch.randn(10, 20).to(self.device)
        b = torch.randn(20, 30).to(self.device)
        c = torch.randn(10, 30).to(self.device)
        example_inputs = (x, a, b, c)
        with config.patch(
            {
                "cpp.enable_kernel_profile": True,
            }
        ):
            torch.compile(model)(*example_inputs)
    @unittest.skipIf(sys.platform == "darwin", "Different kernel names on MacOS")
    def test_aoti_python_stack_traces(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.sigmoid = torch.nn.Sigmoid()
            def forward(self, x, a, b, c):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.sigmoid(x)
                d = a * 3.14
                y = torch.addmm(c, d, b)
                z = torch.nn.functional.gelu(y)
                return x, z
        x = torch.randn(8, 10).to(self.device)
        a = torch.randn(10, 20).to(self.device)
        b = torch.randn(20, 30).to(self.device)
        c = torch.randn(10, 30).to(self.device)
        example_inputs = (x, a, b, c)
        model = Model().to(self.device)
        ep = torch.export.export(model, example_inputs)
        _, code = run_and_get_cpp_code(torch._inductor.aoti_compile_and_package, ep)
        self.assertTrue("KernelContextGuard" not in code)
        with config.patch(
            {
                "trace.provenance_tracking_level": 1,
                "cpp.enable_kernel_profile": True,
            }
        ):
            package_path, code = run_and_get_cpp_code(
                torch._inductor.aoti_compile_and_package, ep
            )
            FileCheck().check(
                "#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
            ).check("thread_local KernelContext* tls_kernel_context = nullptr;").run(
                code
            )
            if self.device == "cuda":
                FileCheck().check(
                    """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"("""
                ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check(
                    """KernelContextGuard _ctx("triton_poi_fused_addmm_relu_sigmoid_0", R"("""
                ).check("call_triton_poi_fused_addmm_relu_sigmoid_0(").check(
                    """KernelContextGuard _ctx("triton_poi_fused_mul_1", R"("""
                ).check("call_triton_poi_fused_mul_1(").check(
                    """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"("""
                ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check(
                    """ KernelContextGuard _ctx("triton_poi_fused_addmm_gelu_2", R"("""
                ).check("call_triton_poi_fused_addmm_gelu_2(").run(code)
            else:
                FileCheck().check(
                    """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"("""
                ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check(
                    """KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"("""
                ).check("cpp_fused_mul_relu_sigmoid_0(").check(
                    """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"("""
                ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check(
                    """ KernelContextGuard _ctx("cpp_fused_gelu_1", R"("""
                ).check("cpp_fused_gelu_1(").run(code)
            compiled_model = torch._inductor.aoti_load_package(package_path)
            result = compiled_model(*example_inputs)
            self.assertEqual(result, model(*example_inputs))
class TestProvenanceTracingKernelContextCpu(TestCase):
    device = "cpu"
copy_tests(
    ProvenanceTracingKernelContextTemplate,
    TestProvenanceTracingKernelContextCpu,
    "cpu",
)
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
@unittest.skipIf(not torch.cuda.is_available(), "No CUDA")
class TestProvenanceTracingKernelContextGpu(TestCase):
    device = "cuda"
copy_tests(
    ProvenanceTracingKernelContextTemplate,
    TestProvenanceTracingKernelContextGpu,
    "cuda",
)
if __name__ == "__main__":
    run_tests()

View File

@@ -91,7 +91,7 @@ class TestUtils(TestCase):
    def test_flops_fx(self):
        def create_fx_node(
            aten: torch._ops.OpOverloadPacket, args, kwargs
            aten, op_overload: torch._ops.OpOverload, args, kwargs
        ) -> tuple[torch.fx.Node, torch.fx.Node]:
            node1 = torch.fx.Node(
                graph=torch.fx.Graph(),
@@ -101,8 +101,13 @@ class TestUtils(TestCase):
                args=args,
                kwargs=kwargs,
            )
            name: str = aten.overloads()[0]
            op_overload: torch._ops.OpOverload = getattr(aten, name)
            # name: str = aten.overloads()[0]
            # if aten == torch.ops.aten.addmm:
            # name = "default"
            # print(aten)
            # print(aten.overloads())
            # print(name)
            # op_overload: torch._ops.OpOverload = getattr(aten, name)
            node2 = torch.fx.Node(
                graph=torch.fx.Graph(),
                name="",
@@ -119,17 +124,25 @@ class TestUtils(TestCase):
        trues = [
            (
                torch.ops.aten.addmm,
                torch.ops.aten.addmm.default,
                (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
                {},
            ),
            (
                torch.ops.aten.bmm,
                torch.ops.aten.bmm.default,
                (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
                {},
            ),
            (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
            (
                torch.ops.aten.mm,
                torch.ops.aten.mm.default,
                (torch.Tensor(2, 3), torch.Tensor(3, 2)),
                {},
            ),
            (
                torch.ops.aten.convolution,
                torch.ops.aten.convolution.default,
                (
                    torch.Tensor(2, 2, 3),
                    torch.Tensor(2, 2, 2),
@@ -145,6 +158,7 @@ class TestUtils(TestCase):
            ),
            (
                torch.ops.aten._convolution,
                torch.ops.aten._convolution.deprecated,
                (
                    torch.Tensor(2, 2, 2),
                    torch.Tensor(2, 2, 2),
@@ -166,17 +180,19 @@ class TestUtils(TestCase):
        falses = [
            (
                torch.ops.aten.add,
                torch.ops.aten.add.Tensor,
                (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                {},
            ),
            (
                torch.ops.aten.mul,
                torch.ops.aten.mul.Tensor,
                (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                {},
            ),
        ]
        for t, args, kwargs in trues:
            fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
        for t, t2, args, kwargs in trues:
            fx_node_1, fx_node_2 = create_fx_node(t, t2, args, kwargs)
            self.assertTrue(
                countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
            )
@@ -185,8 +201,8 @@ class TestUtils(TestCase):
            )
            self.assertNotEqual(count_flops_fx(fx_node_1), None)
            self.assertNotEqual(count_flops_fx(fx_node_2), None)
        for f, args, kwargs in falses:
            fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
        for f, f2, args, kwargs in falses:
            fx_node_1, fx_node_2 = create_fx_node(f, f2, args, kwargs)
            self.assertFalse(
                countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
            )

View File

@@ -5469,7 +5469,16 @@ class CppScheduling(BaseScheduling):
                src_code, self.kernel_group.scheduled_nodes
            )
            self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
            if config.cpp.enable_kernel_profile:
                V.graph.wrapper_code.write_kernel_context_guard_begin()
                V.graph.wrapper_code.write_kernel_context_guard(
                    kernel_name,
                    self.kernel_group.scheduled_nodes, # type: ignore[arg-type]
                )
            self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
            if config.cpp.enable_kernel_profile:
                V.graph.wrapper_code.write_kernel_context_guard_end()
        self.reset_kernel_group()
        self._set_flush_status(False)

View File

@@ -22,6 +22,7 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import config, cpp_builder, ir
from ..ir import ExternKernel
from ..utils import _align, DeferredLineBase, LineContext, normalize_name
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
@@ -43,6 +44,8 @@ if TYPE_CHECKING:
    # At most, the list nesting can go one layer deep.
    _OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
    from ..scheduler import BaseSchedulerNode
class HasWriteLine(Protocol):
    def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
@@ -233,6 +236,18 @@ class CppWrapperCpu(PythonWrapperCodegen):
        self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
        self.header.splice("\n")
        if config.cpp.enable_kernel_profile:
            self.header.splice(
                "#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
            )
            self.header.splice(
                """
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
"""
            )
    def _include_extra_header(self, header: str):
        # This is needed for cpp to python dtype conversion
        self.header.splice(f"#include <{header}>")
@@ -1249,7 +1264,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
        device: str,
        *,
        debug_args: Optional[list[str]] = None,
        debug_handle: Optional[int] = None,
        stack_traces: Optional[OrderedSet[str]] = None,
    ) -> None:
        """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in
        place of args while preserving debug printer output."""
@@ -1266,21 +1281,26 @@ class CppWrapperCpu(PythonWrapperCodegen):
        ]
        with debug_printer_manager:
            shim_fn = self.get_c_shim_func_name(kernel, device)
            self.write_provenance_debug_handle(shim_fn, debug_handle)
            shim_fn_codes = (
            shim_fn_codes = [
                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
            )
            ]
            if enable_kernel_profile:
                debug_handle_str = "" if debug_handle is None else f":{debug_handle}"
                shim_fn_codes = textwrap.dedent(
                    f"""
                    {{
                    RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr);
                    {shim_fn_codes}
                    }}
                    """
                )
            self.writeline(shim_fn_codes)
                stack_trace_str = 'R"('
                if stack_traces:
                    for stack_trace in stack_traces:
                        for line in stack_trace.split("\n"):
                            stack_trace_str += f"\n{line}"
                        stack_trace_str += "\n"
                stack_trace_str += ')"'
                shim_fn_codes = [
                    "{",
                    f"""KernelContextGuard _ctx("{shim_fn}", {stack_trace_str});""",
                    f"""RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);""",
                    shim_fn_codes[0],
                    "}",
                ]
            self.writelines(shim_fn_codes)
    def generate_c_shim_extern_kernel_alloc(
        self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
@@ -1373,7 +1393,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
        out_view: Optional[str],
        args: list[str],
        device: str,
        debug_handle: Optional[int] = None,
        stack_traces: Optional[OrderedSet[str]] = None,
    ) -> None:
        if out_view:
            out_name = f"{out}_as_strided"
@@ -1383,7 +1403,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
            args.insert(0, out)
        self.generate_c_shim_extern_kernel_call(
            kernel, args, device, debug_handle=debug_handle
            kernel, args, device, stack_traces=stack_traces
        )
    def _get_scatter_reduce_enum(self, reduce):
@@ -2897,3 +2917,53 @@ if (!custom_op_wrapper) {
        writer.writeline(call_str)
        return tmp_var_name
    def write_kernel_context_guard_begin(
        self,
    ):
        # Beginning of a kernel context guarded block.
        # The block looks like this:
        # {
        #     KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});
        #     ... operations...
        # }
        self.writeline("{")
    def write_kernel_context_guard_end(
        self,
    ):
        # End of a kernel context guarded block.
        self.writeline("}")
    def write_kernel_context_guard(
        self,
        kernel_name: str,
        node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    ):
        def aggregate_stack_traces(
            node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
        ) -> OrderedSet[str]:
            if isinstance(node_schedule, list):
                return functools.reduce(
                    lambda a, b: a | b,
                    [
                        node.node.get_stack_traces()
                        for node in node_schedule
                        if hasattr(node, "node") and node.node
                    ],
                    OrderedSet(),
                )
            elif isinstance(node_schedule, ExternKernel):
                return node_schedule.get_stack_traces()
            else:
                return OrderedSet()
        stack_trace_str = 'R"('
        stack_traces = aggregate_stack_traces(node_schedule)
        for stack_trace in stack_traces:
            for line in stack_trace.split("\n"):
                stack_trace_str += f"\n{line}"
            stack_trace_str += "\n"
        stack_trace_str += ')"'
        self.writeline(f'KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});')

View File

@@ -1707,6 +1707,9 @@ class SIMDScheduling(BaseScheduling):
        return True
    def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
        """
        Generate code for nodes in kernel_features
        """
        node_schedule = kernel_features.node_schedule
        tiling, tiling_score = self.get_tiling_and_scores(
@@ -1747,7 +1750,15 @@ class SIMDScheduling(BaseScheduling):
            node for node in node_schedule if isinstance(node, BaseSchedulerNode)
        ]
        self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name)
        if config.cpp.enable_kernel_profile:
            V.graph.wrapper_code.write_kernel_context_guard_begin()
            V.graph.wrapper_code.write_kernel_context_guard(
                final_kernel.kernel_name,
                base_scheduler_nodes, # type: ignore[arg-type]
            )
        final_kernel.call_kernel(final_kernel.kernel_name)
        if config.cpp.enable_kernel_profile:
            V.graph.wrapper_code.write_kernel_context_guard_end()
        if config.nan_asserts:
            final_kernel.codegen_nan_check()

View File

@@ -77,6 +77,8 @@ if TYPE_CHECKING:
    import triton
    from ..graph import GraphLowering
    from ..ir import ExternKernel
    from ..scheduler import BaseSchedulerNode
    from .wrapper_fxir import FxConverter
@@ -528,6 +530,7 @@ class ExternKernelOutLine(WrapperLine):
            node.output_view.codegen_reference() if node.output_view else None,
            args,
            device,
            self.node.get_stack_traces(),
        )
    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
@@ -1554,6 +1557,7 @@ class PythonWrapperCodegen(CodeGen):
        out_view: Optional[str],
        args: list[str],
        device: str,
        stack_traces: Optional[OrderedSet[str]] = None,
    ) -> None:
        # add debug printer code for triton kernel calls at (jit) inductor level
        debug_printer_manager = V.graph.wrapper_code.debug_printer
@@ -3690,6 +3694,29 @@ def can_prove_buffer_has_static_shape(buffer):
    def can_prove_buffer_has_static_shape(buffer):
        return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None
    def write_kernel_context_guard(
        self,
        kernel_name: str,
        node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    ):
        return
    def write_kernel_context_guard_begin(
        self,
    ):
        """
        Mark the beginning of kernel context guard
        """
        return
    def write_kernel_context_guard_end(
        self,
    ):
        """
        Mark the end of kernel context guard
        """
        return
class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
    """

View File

@@ -0,0 +1,51 @@
#pragma once
#include <string>
#include <utility>
namespace torch::aot_inductor {
struct KernelContext {
  std::string kernel_name;
  std::string python_stack;
  KernelContext(std::string name, std::string stack)
      : kernel_name(std::move(name)), python_stack(std::move(stack)) {}
};
// Thread-local pointer
extern thread_local KernelContext* tls_kernel_context;
inline KernelContext* current_kernel_context() {
  return tls_kernel_context;
}
inline void set_kernel_context(KernelContext* ctx) {
  tls_kernel_context = ctx;
}
inline void clear_kernel_context() {
  tls_kernel_context = nullptr;
}
struct KernelContextGuard {
  KernelContextGuard(const std::string& name, const std::string& stack)
      : owned_context_(name, stack) {
    set_kernel_context(&owned_context_);
  }
  ~KernelContextGuard() {
    clear_kernel_context();
  }
  // Delete copy constructor and copy assignment operator
  KernelContextGuard(const KernelContextGuard&) = delete;
  KernelContextGuard& operator=(const KernelContextGuard&) = delete;
  KernelContextGuard(KernelContextGuard&&) = default;
  KernelContextGuard& operator=(KernelContextGuard&&) = delete;
 private:
  KernelContext owned_context_;
};
} // namespace torch::aot_inductor