mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
We add a `thread_local` `KernelContext` object so that Strobelight (and potentially other profilers) can read the stack trace information of the currently running kernel.
Because this adds extra overhead, it is guarded behind the `cpp.enable_kernel_profile` flag.
Example output code:
```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {
__check_inputs_outputs(input_handles, output_handles);
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
auto arg2_1 = std::move(inputs[0]);
auto arg3_1 = std::move(inputs[1]);
auto arg4_1 = std::move(inputs[2]);
auto arg5_1 = std::move(inputs[3]);
[[maybe_unused]] auto& fc1_weight = constants_->at(0);
[[maybe_unused]] auto& fc1_bias = constants_->at(1);
inputs.clear();
[[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
static constexpr int64_t int_array_0[] = {8L, 16L};
static constexpr int64_t int_array_1[] = {16L, 1L};
AtenTensorHandle buf0_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
RAIIAtenTensorHandle buf0(buf0_handle);
// Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
// [Provenance debug handles] aoti_torch_cpu_addmm_out:1
static constexpr int64_t int_array_2[] = {10L, 16L};
static constexpr int64_t int_array_3[] = {1L, 10L};
{
KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
x = self.fc1(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
)");
RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
}
arg2_1.reset();
auto buf1 = std::move(buf0); // reuse
static constexpr int64_t int_array_4[] = {10L, 20L};
static constexpr int64_t int_array_5[] = {20L, 1L};
AtenTensorHandle buf2_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
RAIIAtenTensorHandle buf2(buf2_handle);
// [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
{
KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
x = self.sigmoid(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
x = self.relu(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
d = a * 3.14
)");
cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
}
arg3_1.reset();
static constexpr int64_t int_array_6[] = {10L, 30L};
static constexpr int64_t int_array_7[] = {30L, 1L};
AtenTensorHandle buf3_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
RAIIAtenTensorHandle buf3(buf3_handle);
// Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
// [Provenance debug handles] aoti_torch_cpu_addmm_out:3
{
KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
y = torch.addmm(c, d, b)
)");
RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
}
arg4_1.reset();
arg5_1.reset();
buf2.reset();
auto buf4 = std::move(buf3); // reuse
// [Provenance debug handles] cpp_fused_gelu_1:4
{
KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
z = torch.nn.functional.gelu(y)
)");
cpp_fused_gelu_1((float*)(buf4.data_ptr()));
}
output_handles[0] = buf1.release();
output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```
Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r stack_traces
```
Rollback Plan:
Differential Revision: D78436007
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
| | | |
|---|---|---|
| .. | ||
| aoti_runtime | ||
| cuda | ||
| cutedsl | ||
| mtia | ||
| rocm | ||
| xpu | ||
| __init__.py | ||
| aoti_hipify_utils.py | ||
| block_analysis.py | ||
| common.py | ||
| cpp_bmm_template.py | ||
| cpp_flex_attention_template.py | ||
| cpp_gemm_template.py | ||
| cpp_grouped_gemm_template.py | ||
| cpp_micro_gemm.py | ||
| cpp_template_kernel.py | ||
| cpp_template.py | ||
| cpp_utils.py | ||
| cpp_wrapper_cpu_array_ref.py | ||
| cpp_wrapper_cpu.py | ||
| cpp_wrapper_gpu.py | ||
| cpp_wrapper_mps.py | ||
| cpp.py | ||
| cpu_device_op_overrides.py | ||
| cuda_combined_scheduling.py | ||
| debug_utils.py | ||
| halide.py | ||
| memory_planning.py | ||
| mps_device_op_overrides.py | ||
| mps.py | ||
| multi_kernel.py | ||
| python_wrapper_mtia.py | ||
| segmented_tree.py | ||
| simd_kernel_features.py | ||
| simd.py | ||
| subgraph.py | ||
| triton_combo_kernel.py | ||
| triton_split_scan.py | ||
| triton_utils.py | ||
| triton.py | ||
| wrapper_fxir.py | ||
| wrapper.py | ||