mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Add python stack trace to AOTI generated code (#160539)
Summary:
We add a thread_local KernelContext object so Strobelight (and other potential profilers) can read the stack trace information of the running kernel.
This will bring extra overhead, so we guard this behind the `cpp.enable_kernel_profile` flag.
Example output code:
```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {
__check_inputs_outputs(input_handles, output_handles);
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
auto arg2_1 = std::move(inputs[0]);
auto arg3_1 = std::move(inputs[1]);
auto arg4_1 = std::move(inputs[2]);
auto arg5_1 = std::move(inputs[3]);
[[maybe_unused]] auto& fc1_weight = constants_->at(0);
[[maybe_unused]] auto& fc1_bias = constants_->at(1);
inputs.clear();
[[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
static constexpr int64_t int_array_0[] = {8L, 16L};
static constexpr int64_t int_array_1[] = {16L, 1L};
AtenTensorHandle buf0_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
RAIIAtenTensorHandle buf0(buf0_handle);
// Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
// [Provenance debug handles] aoti_torch_cpu_addmm_out:1
static constexpr int64_t int_array_2[] = {10L, 16L};
static constexpr int64_t int_array_3[] = {1L, 10L};
{
KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
x = self.fc1(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
)");
RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
}
arg2_1.reset();
auto buf1 = std::move(buf0); // reuse
static constexpr int64_t int_array_4[] = {10L, 20L};
static constexpr int64_t int_array_5[] = {20L, 1L};
AtenTensorHandle buf2_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
RAIIAtenTensorHandle buf2(buf2_handle);
// [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
{
KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
x = self.sigmoid(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
x = self.relu(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
d = a * 3.14
)");
cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
}
arg3_1.reset();
static constexpr int64_t int_array_6[] = {10L, 30L};
static constexpr int64_t int_array_7[] = {30L, 1L};
AtenTensorHandle buf3_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
RAIIAtenTensorHandle buf3(buf3_handle);
// Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
// [Provenance debug handles] aoti_torch_cpu_addmm_out:3
{
KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
y = torch.addmm(c, d, b)
)");
RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
}
arg4_1.reset();
arg5_1.reset();
buf2.reset();
auto buf4 = std::move(buf3); // reuse
// [Provenance debug handles] cpp_fused_gelu_1:4
{
KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
z = torch.nn.functional.gelu(y)
)");
cpp_fused_gelu_1((float*)(buf4.data_ptr()));
}
output_handles[0] = buf1.release();
output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```
Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r stack_traces
```
Rollback Plan:
Differential Revision: D78436007
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
This commit is contained in:
parent
972030fe2e
commit
d2eff5d454
|
|
@ -7,6 +7,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import zipfile
|
||||
|
|
@ -24,7 +25,7 @@ from torch._inductor.debug import (
|
|||
)
|
||||
from torch._inductor.fx_passes.post_grad import post_grad_passes
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
|
@ -32,8 +33,12 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
|||
|
||||
try:
|
||||
from .test_aot_inductor_utils import AOTIRunnerUtil
|
||||
from .test_torchinductor import copy_tests
|
||||
except ImportError:
|
||||
from test_aot_inductor_utils import AOTIRunnerUtil
|
||||
from test_torchinductor import (
|
||||
copy_tests, # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
||||
)
|
||||
|
||||
|
||||
trace_log = logging.getLogger("torch.__trace")
|
||||
|
|
@ -806,5 +811,135 @@ class TestProvenanceTracingStackTraces(TestCase):
|
|||
self.assertTrue("aoti_torch_cpu_convolution" in keys)
|
||||
|
||||
|
||||
class ProvenanceTracingKernelContextTemplate:
|
||||
def test_jit_inductor_with_flag(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(10, 16)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.sigmoid = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x, a, b, c):
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.sigmoid(x)
|
||||
d = a * 3.14
|
||||
y = torch.addmm(c, d, b)
|
||||
z = torch.nn.functional.gelu(y)
|
||||
return x, z
|
||||
|
||||
model = Model().to(self.device)
|
||||
x = torch.randn(8, 10).to(self.device)
|
||||
a = torch.randn(10, 20).to(self.device)
|
||||
b = torch.randn(20, 30).to(self.device)
|
||||
c = torch.randn(10, 30).to(self.device)
|
||||
example_inputs = (x, a, b, c)
|
||||
|
||||
with config.patch(
|
||||
{
|
||||
"cpp.enable_kernel_profile": True,
|
||||
}
|
||||
):
|
||||
torch.compile(model)(*example_inputs)
|
||||
|
||||
@unittest.skipIf(sys.platform == "darwin", "Different kernel names on MacOS")
|
||||
def test_aoti_python_stack_traces(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(10, 16)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.sigmoid = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x, a, b, c):
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.sigmoid(x)
|
||||
d = a * 3.14
|
||||
y = torch.addmm(c, d, b)
|
||||
z = torch.nn.functional.gelu(y)
|
||||
return x, z
|
||||
|
||||
x = torch.randn(8, 10).to(self.device)
|
||||
a = torch.randn(10, 20).to(self.device)
|
||||
b = torch.randn(20, 30).to(self.device)
|
||||
c = torch.randn(10, 30).to(self.device)
|
||||
example_inputs = (x, a, b, c)
|
||||
model = Model().to(self.device)
|
||||
|
||||
ep = torch.export.export(model, example_inputs)
|
||||
_, code = run_and_get_cpp_code(torch._inductor.aoti_compile_and_package, ep)
|
||||
|
||||
self.assertTrue("KernelContextGuard" not in code)
|
||||
|
||||
with config.patch(
|
||||
{
|
||||
"trace.provenance_tracking_level": 1,
|
||||
"cpp.enable_kernel_profile": True,
|
||||
}
|
||||
):
|
||||
package_path, code = run_and_get_cpp_code(
|
||||
torch._inductor.aoti_compile_and_package, ep
|
||||
)
|
||||
|
||||
FileCheck().check(
|
||||
"#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
|
||||
).check("thread_local KernelContext* tls_kernel_context = nullptr;").run(
|
||||
code
|
||||
)
|
||||
|
||||
if self.device == "cuda":
|
||||
FileCheck().check(
|
||||
"""KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"("""
|
||||
).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check(
|
||||
"""KernelContextGuard _ctx("triton_poi_fused_addmm_relu_sigmoid_0", R"("""
|
||||
).check("call_triton_poi_fused_addmm_relu_sigmoid_0(").check(
|
||||
"""KernelContextGuard _ctx("triton_poi_fused_mul_1", R"("""
|
||||
).check("call_triton_poi_fused_mul_1(").check(
|
||||
"""KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"("""
|
||||
).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check(
|
||||
""" KernelContextGuard _ctx("triton_poi_fused_addmm_gelu_2", R"("""
|
||||
).check("call_triton_poi_fused_addmm_gelu_2(").run(code)
|
||||
else:
|
||||
FileCheck().check(
|
||||
"""KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"("""
|
||||
).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check(
|
||||
"""KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"("""
|
||||
).check("cpp_fused_mul_relu_sigmoid_0(").check(
|
||||
"""KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"("""
|
||||
).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check(
|
||||
""" KernelContextGuard _ctx("cpp_fused_gelu_1", R"("""
|
||||
).check("cpp_fused_gelu_1(").run(code)
|
||||
|
||||
compiled_model = torch._inductor.aoti_load_package(package_path)
|
||||
result = compiled_model(*example_inputs)
|
||||
self.assertEqual(result, model(*example_inputs))
|
||||
|
||||
|
||||
class TestProvenanceTracingKernelContextCpu(TestCase):
|
||||
device = "cpu"
|
||||
|
||||
|
||||
copy_tests(
|
||||
ProvenanceTracingKernelContextTemplate,
|
||||
TestProvenanceTracingKernelContextCpu,
|
||||
"cpu",
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "No CUDA")
|
||||
class TestProvenanceTracingKernelContextGpu(TestCase):
|
||||
device = "cuda"
|
||||
|
||||
|
||||
copy_tests(
|
||||
ProvenanceTracingKernelContextTemplate,
|
||||
TestProvenanceTracingKernelContextGpu,
|
||||
"cuda",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class TestUtils(TestCase):
|
|||
|
||||
def test_flops_fx(self):
|
||||
def create_fx_node(
|
||||
aten: torch._ops.OpOverloadPacket, args, kwargs
|
||||
aten, op_overload: torch._ops.OpOverload, args, kwargs
|
||||
) -> tuple[torch.fx.Node, torch.fx.Node]:
|
||||
node1 = torch.fx.Node(
|
||||
graph=torch.fx.Graph(),
|
||||
|
|
@ -101,8 +101,13 @@ class TestUtils(TestCase):
|
|||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
name: str = aten.overloads()[0]
|
||||
op_overload: torch._ops.OpOverload = getattr(aten, name)
|
||||
# name: str = aten.overloads()[0]
|
||||
# if aten == torch.ops.aten.addmm:
|
||||
# name = "default"
|
||||
# print(aten)
|
||||
# print(aten.overloads())
|
||||
# print(name)
|
||||
# op_overload: torch._ops.OpOverload = getattr(aten, name)
|
||||
node2 = torch.fx.Node(
|
||||
graph=torch.fx.Graph(),
|
||||
name="",
|
||||
|
|
@ -119,17 +124,25 @@ class TestUtils(TestCase):
|
|||
trues = [
|
||||
(
|
||||
torch.ops.aten.addmm,
|
||||
torch.ops.aten.addmm.default,
|
||||
(torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
|
||||
{},
|
||||
),
|
||||
(
|
||||
torch.ops.aten.bmm,
|
||||
torch.ops.aten.bmm.default,
|
||||
(torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
|
||||
{},
|
||||
),
|
||||
(torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
|
||||
(
|
||||
torch.ops.aten.mm,
|
||||
torch.ops.aten.mm.default,
|
||||
(torch.Tensor(2, 3), torch.Tensor(3, 2)),
|
||||
{},
|
||||
),
|
||||
(
|
||||
torch.ops.aten.convolution,
|
||||
torch.ops.aten.convolution.default,
|
||||
(
|
||||
torch.Tensor(2, 2, 3),
|
||||
torch.Tensor(2, 2, 2),
|
||||
|
|
@ -145,6 +158,7 @@ class TestUtils(TestCase):
|
|||
),
|
||||
(
|
||||
torch.ops.aten._convolution,
|
||||
torch.ops.aten._convolution.deprecated,
|
||||
(
|
||||
torch.Tensor(2, 2, 2),
|
||||
torch.Tensor(2, 2, 2),
|
||||
|
|
@ -166,17 +180,19 @@ class TestUtils(TestCase):
|
|||
falses = [
|
||||
(
|
||||
torch.ops.aten.add,
|
||||
torch.ops.aten.add.Tensor,
|
||||
(torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
|
||||
{},
|
||||
),
|
||||
(
|
||||
torch.ops.aten.mul,
|
||||
torch.ops.aten.mul.Tensor,
|
||||
(torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
|
||||
{},
|
||||
),
|
||||
]
|
||||
for t, args, kwargs in trues:
|
||||
fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
|
||||
for t, t2, args, kwargs in trues:
|
||||
fx_node_1, fx_node_2 = create_fx_node(t, t2, args, kwargs)
|
||||
self.assertTrue(
|
||||
countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
|
||||
)
|
||||
|
|
@ -185,8 +201,8 @@ class TestUtils(TestCase):
|
|||
)
|
||||
self.assertNotEqual(count_flops_fx(fx_node_1), None)
|
||||
self.assertNotEqual(count_flops_fx(fx_node_2), None)
|
||||
for f, args, kwargs in falses:
|
||||
fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
|
||||
for f, f2, args, kwargs in falses:
|
||||
fx_node_1, fx_node_2 = create_fx_node(f, f2, args, kwargs)
|
||||
self.assertFalse(
|
||||
countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5469,7 +5469,16 @@ class CppScheduling(BaseScheduling):
|
|||
src_code, self.kernel_group.scheduled_nodes
|
||||
)
|
||||
self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
|
||||
if config.cpp.enable_kernel_profile:
|
||||
V.graph.wrapper_code.write_kernel_context_guard_begin()
|
||||
V.graph.wrapper_code.write_kernel_context_guard(
|
||||
kernel_name,
|
||||
self.kernel_group.scheduled_nodes, # type: ignore[arg-type]
|
||||
)
|
||||
self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
|
||||
if config.cpp.enable_kernel_profile:
|
||||
V.graph.wrapper_code.write_kernel_context_guard_end()
|
||||
|
||||
self.reset_kernel_group()
|
||||
self._set_flush_status(False)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from torch.utils._ordered_set import OrderedSet
|
|||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
|
||||
from .. import config, cpp_builder, ir
|
||||
from ..ir import ExternKernel
|
||||
from ..utils import _align, DeferredLineBase, LineContext, normalize_name
|
||||
from ..virtualized import V
|
||||
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
||||
|
|
@ -43,6 +44,8 @@ if TYPE_CHECKING:
|
|||
# At most, the list nesting can go one layer deep.
|
||||
_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
|
||||
|
||||
from ..scheduler import BaseSchedulerNode
|
||||
|
||||
|
||||
class HasWriteLine(Protocol):
|
||||
def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
|
||||
|
|
@ -233,6 +236,18 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
|
||||
self.header.splice("\n")
|
||||
|
||||
if config.cpp.enable_kernel_profile:
|
||||
self.header.splice(
|
||||
"#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
|
||||
)
|
||||
self.header.splice(
|
||||
"""
|
||||
namespace torch::aot_inductor {
|
||||
thread_local KernelContext* tls_kernel_context = nullptr;
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def _include_extra_header(self, header: str):
|
||||
# This is needed for cpp to python dtype conversion
|
||||
self.header.splice(f"#include <{header}>")
|
||||
|
|
@ -1249,7 +1264,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
device: str,
|
||||
*,
|
||||
debug_args: Optional[list[str]] = None,
|
||||
debug_handle: Optional[int] = None,
|
||||
stack_traces: Optional[OrderedSet[str]] = None,
|
||||
) -> None:
|
||||
"""debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in
|
||||
place of args while preserving debug printer output."""
|
||||
|
|
@ -1266,21 +1281,26 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
]
|
||||
with debug_printer_manager:
|
||||
shim_fn = self.get_c_shim_func_name(kernel, device)
|
||||
self.write_provenance_debug_handle(shim_fn, debug_handle)
|
||||
shim_fn_codes = (
|
||||
shim_fn_codes = [
|
||||
f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
|
||||
)
|
||||
]
|
||||
if enable_kernel_profile:
|
||||
debug_handle_str = "" if debug_handle is None else f":{debug_handle}"
|
||||
shim_fn_codes = textwrap.dedent(
|
||||
f"""
|
||||
{{
|
||||
RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr);
|
||||
{shim_fn_codes}
|
||||
}}
|
||||
"""
|
||||
)
|
||||
self.writeline(shim_fn_codes)
|
||||
stack_trace_str = 'R"('
|
||||
if stack_traces:
|
||||
for stack_trace in stack_traces:
|
||||
for line in stack_trace.split("\n"):
|
||||
stack_trace_str += f"\n{line}"
|
||||
stack_trace_str += "\n"
|
||||
stack_trace_str += ')"'
|
||||
|
||||
shim_fn_codes = [
|
||||
"{",
|
||||
f"""KernelContextGuard _ctx("{shim_fn}", {stack_trace_str});""",
|
||||
f"""RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);""",
|
||||
shim_fn_codes[0],
|
||||
"}",
|
||||
]
|
||||
self.writelines(shim_fn_codes)
|
||||
|
||||
def generate_c_shim_extern_kernel_alloc(
|
||||
self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
|
||||
|
|
@ -1373,7 +1393,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
out_view: Optional[str],
|
||||
args: list[str],
|
||||
device: str,
|
||||
debug_handle: Optional[int] = None,
|
||||
stack_traces: Optional[OrderedSet[str]] = None,
|
||||
) -> None:
|
||||
if out_view:
|
||||
out_name = f"{out}_as_strided"
|
||||
|
|
@ -1383,7 +1403,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
args.insert(0, out)
|
||||
|
||||
self.generate_c_shim_extern_kernel_call(
|
||||
kernel, args, device, debug_handle=debug_handle
|
||||
kernel, args, device, stack_traces=stack_traces
|
||||
)
|
||||
|
||||
def _get_scatter_reduce_enum(self, reduce):
|
||||
|
|
@ -2897,3 +2917,53 @@ if (!custom_op_wrapper) {
|
|||
writer.writeline(call_str)
|
||||
|
||||
return tmp_var_name
|
||||
|
||||
def write_kernel_context_guard_begin(
|
||||
self,
|
||||
):
|
||||
# Beginning of a kernel context guarded block.
|
||||
# The block looks like this:
|
||||
# {
|
||||
# KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});
|
||||
# ... operations...
|
||||
# }
|
||||
self.writeline("{")
|
||||
|
||||
def write_kernel_context_guard_end(
|
||||
self,
|
||||
):
|
||||
# End of a kernel context guarded block.
|
||||
self.writeline("}")
|
||||
|
||||
def write_kernel_context_guard(
|
||||
self,
|
||||
kernel_name: str,
|
||||
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
||||
):
|
||||
def aggregate_stack_traces(
|
||||
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
||||
) -> OrderedSet[str]:
|
||||
if isinstance(node_schedule, list):
|
||||
return functools.reduce(
|
||||
lambda a, b: a | b,
|
||||
[
|
||||
node.node.get_stack_traces()
|
||||
for node in node_schedule
|
||||
if hasattr(node, "node") and node.node
|
||||
],
|
||||
OrderedSet(),
|
||||
)
|
||||
elif isinstance(node_schedule, ExternKernel):
|
||||
return node_schedule.get_stack_traces()
|
||||
else:
|
||||
return OrderedSet()
|
||||
|
||||
stack_trace_str = 'R"('
|
||||
stack_traces = aggregate_stack_traces(node_schedule)
|
||||
|
||||
for stack_trace in stack_traces:
|
||||
for line in stack_trace.split("\n"):
|
||||
stack_trace_str += f"\n{line}"
|
||||
stack_trace_str += "\n"
|
||||
stack_trace_str += ')"'
|
||||
self.writeline(f'KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});')
|
||||
|
|
|
|||
|
|
@ -1707,6 +1707,9 @@ class SIMDScheduling(BaseScheduling):
|
|||
return True
|
||||
|
||||
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
|
||||
"""
|
||||
Generate code for nodes in kernel_features
|
||||
"""
|
||||
node_schedule = kernel_features.node_schedule
|
||||
|
||||
tiling, tiling_score = self.get_tiling_and_scores(
|
||||
|
|
@ -1747,7 +1750,15 @@ class SIMDScheduling(BaseScheduling):
|
|||
node for node in node_schedule if isinstance(node, BaseSchedulerNode)
|
||||
]
|
||||
self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name)
|
||||
if config.cpp.enable_kernel_profile:
|
||||
V.graph.wrapper_code.write_kernel_context_guard_begin()
|
||||
V.graph.wrapper_code.write_kernel_context_guard(
|
||||
final_kernel.kernel_name,
|
||||
base_scheduler_nodes, # type: ignore[arg-type]
|
||||
)
|
||||
final_kernel.call_kernel(final_kernel.kernel_name)
|
||||
if config.cpp.enable_kernel_profile:
|
||||
V.graph.wrapper_code.write_kernel_context_guard_end()
|
||||
|
||||
if config.nan_asserts:
|
||||
final_kernel.codegen_nan_check()
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ if TYPE_CHECKING:
|
|||
import triton
|
||||
|
||||
from ..graph import GraphLowering
|
||||
from ..ir import ExternKernel
|
||||
from ..scheduler import BaseSchedulerNode
|
||||
from .wrapper_fxir import FxConverter
|
||||
|
||||
|
||||
|
|
@ -528,6 +530,7 @@ class ExternKernelOutLine(WrapperLine):
|
|||
node.output_view.codegen_reference() if node.output_view else None,
|
||||
args,
|
||||
device,
|
||||
self.node.get_stack_traces(),
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
|
|
@ -1554,6 +1557,7 @@ class PythonWrapperCodegen(CodeGen):
|
|||
out_view: Optional[str],
|
||||
args: list[str],
|
||||
device: str,
|
||||
stack_traces: Optional[OrderedSet[str]] = None,
|
||||
) -> None:
|
||||
# add debug printer code for triton kernel calls at (jit) inductor level
|
||||
debug_printer_manager = V.graph.wrapper_code.debug_printer
|
||||
|
|
@ -3690,6 +3694,29 @@ class PythonWrapperCodegen(CodeGen):
|
|||
def can_prove_buffer_has_static_shape(buffer):
|
||||
return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None
|
||||
|
||||
def write_kernel_context_guard(
|
||||
self,
|
||||
kernel_name: str,
|
||||
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
||||
):
|
||||
return
|
||||
|
||||
def write_kernel_context_guard_begin(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
Mark the beginning of kernel context guard
|
||||
"""
|
||||
return
|
||||
|
||||
def write_kernel_context_guard_end(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
Mark the end of kernel context guard
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
|
||||
"""
|
||||
|
|
|
|||
51
torch/csrc/inductor/aoti_runtime/kernel_context_tls.h
Normal file
51
torch/csrc/inductor/aoti_runtime/kernel_context_tls.h
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
struct KernelContext {
|
||||
std::string kernel_name;
|
||||
std::string python_stack;
|
||||
|
||||
KernelContext(std::string name, std::string stack)
|
||||
: kernel_name(std::move(name)), python_stack(std::move(stack)) {}
|
||||
};
|
||||
|
||||
// Thread-local pointer
|
||||
extern thread_local KernelContext* tls_kernel_context;
|
||||
|
||||
inline KernelContext* current_kernel_context() {
|
||||
return tls_kernel_context;
|
||||
}
|
||||
|
||||
inline void set_kernel_context(KernelContext* ctx) {
|
||||
tls_kernel_context = ctx;
|
||||
}
|
||||
|
||||
inline void clear_kernel_context() {
|
||||
tls_kernel_context = nullptr;
|
||||
}
|
||||
|
||||
struct KernelContextGuard {
|
||||
KernelContextGuard(const std::string& name, const std::string& stack)
|
||||
: owned_context_(name, stack) {
|
||||
set_kernel_context(&owned_context_);
|
||||
}
|
||||
~KernelContextGuard() {
|
||||
clear_kernel_context();
|
||||
}
|
||||
|
||||
// Delete copy constructor and copy assignment operator
|
||||
KernelContextGuard(const KernelContextGuard&) = delete;
|
||||
KernelContextGuard& operator=(const KernelContextGuard&) = delete;
|
||||
|
||||
KernelContextGuard(KernelContextGuard&&) = default;
|
||||
KernelContextGuard& operator=(KernelContextGuard&&) = delete;
|
||||
|
||||
private:
|
||||
KernelContext owned_context_;
|
||||
};
|
||||
|
||||
} // namespace torch::aot_inductor
|
||||
Loading…
Reference in New Issue
Block a user