mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add python stack trace to AOTI generated code (#160539)
Summary:
We add a thread_local KernelContext object so Strobelight (and other potential profilers) can read the stack trace information of the running kernel.
This adds runtime overhead to every kernel launch, so the guard is only emitted when the `cpp.enable_kernel_profile` flag is enabled.
Example output code:
```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {
__check_inputs_outputs(input_handles, output_handles);
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
auto arg2_1 = std::move(inputs[0]);
auto arg3_1 = std::move(inputs[1]);
auto arg4_1 = std::move(inputs[2]);
auto arg5_1 = std::move(inputs[3]);
[[maybe_unused]] auto& fc1_weight = constants_->at(0);
[[maybe_unused]] auto& fc1_bias = constants_->at(1);
inputs.clear();
[[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
static constexpr int64_t int_array_0[] = {8L, 16L};
static constexpr int64_t int_array_1[] = {16L, 1L};
AtenTensorHandle buf0_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
RAIIAtenTensorHandle buf0(buf0_handle);
// Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
// [Provenance debug handles] aoti_torch_cpu_addmm_out:1
static constexpr int64_t int_array_2[] = {10L, 16L};
static constexpr int64_t int_array_3[] = {1L, 10L};
{
KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
x = self.fc1(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
)");
RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
}
arg2_1.reset();
auto buf1 = std::move(buf0); // reuse
static constexpr int64_t int_array_4[] = {10L, 20L};
static constexpr int64_t int_array_5[] = {20L, 1L};
AtenTensorHandle buf2_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
RAIIAtenTensorHandle buf2(buf2_handle);
// [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
{
KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
x = self.sigmoid(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
x = self.relu(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
d = a * 3.14
)");
cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
}
arg3_1.reset();
static constexpr int64_t int_array_6[] = {10L, 30L};
static constexpr int64_t int_array_7[] = {30L, 1L};
AtenTensorHandle buf3_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
RAIIAtenTensorHandle buf3(buf3_handle);
// Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
// [Provenance debug handles] aoti_torch_cpu_addmm_out:3
{
KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
y = torch.addmm(c, d, b)
)");
RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
}
arg4_1.reset();
arg5_1.reset();
buf2.reset();
auto buf4 = std::move(buf3); // reuse
// [Provenance debug handles] cpp_fused_gelu_1:4
{
KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
z = torch.nn.functional.gelu(y)
)");
cpp_fused_gelu_1((float*)(buf4.data_ptr()));
}
output_handles[0] = buf1.release();
output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```
Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r stack_traces
```
Rollback Plan:
Differential Revision: D78436007
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
This commit is contained in:
parent
972030fe2e
commit
d2eff5d454
|
|
@ -7,6 +7,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
@ -24,7 +25,7 @@ from torch._inductor.debug import (
|
||||||
)
|
)
|
||||||
from torch._inductor.fx_passes.post_grad import post_grad_passes
|
from torch._inductor.fx_passes.post_grad import post_grad_passes
|
||||||
from torch._inductor.test_case import run_tests, TestCase
|
from torch._inductor.test_case import run_tests, TestCase
|
||||||
from torch._inductor.utils import run_and_get_code
|
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
|
||||||
from torch._inductor.virtualized import V
|
from torch._inductor.virtualized import V
|
||||||
from torch.testing._internal.common_utils import IS_MACOS
|
from torch.testing._internal.common_utils import IS_MACOS
|
||||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||||
|
|
@ -32,8 +33,12 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .test_aot_inductor_utils import AOTIRunnerUtil
|
from .test_aot_inductor_utils import AOTIRunnerUtil
|
||||||
|
from .test_torchinductor import copy_tests
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from test_aot_inductor_utils import AOTIRunnerUtil
|
from test_aot_inductor_utils import AOTIRunnerUtil
|
||||||
|
from test_torchinductor import (
|
||||||
|
copy_tests, # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
trace_log = logging.getLogger("torch.__trace")
|
trace_log = logging.getLogger("torch.__trace")
|
||||||
|
|
@ -806,5 +811,135 @@ class TestProvenanceTracingStackTraces(TestCase):
|
||||||
self.assertTrue("aoti_torch_cpu_convolution" in keys)
|
self.assertTrue("aoti_torch_cpu_convolution" in keys)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_kernel_context_model_and_inputs(device):
    """Build the small model and example inputs shared by the kernel-context tests.

    Kept at module level (not on the template class) so the test bodies do not
    rely on helper attributes being present on the concrete TestCase classes.

    Returns:
        A ``(model, example_inputs)`` pair, both already moved to ``device``.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 16)
            self.relu = torch.nn.ReLU()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x, a, b, c):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.sigmoid(x)
            d = a * 3.14
            y = torch.addmm(c, d, b)
            z = torch.nn.functional.gelu(y)
            return x, z

    model = Model().to(device)
    x = torch.randn(8, 10).to(device)
    a = torch.randn(10, 20).to(device)
    b = torch.randn(20, 30).to(device)
    c = torch.randn(10, 30).to(device)
    return model, (x, a, b, c)


class ProvenanceTracingKernelContextTemplate:
    """Device-agnostic tests for the ``KernelContextGuard`` code generation.

    The tests read ``self.device``, which the concrete test classes provide.
    """

    def test_jit_inductor_with_flag(self):
        """Smoke test: JIT compilation with kernel profiling enabled must not raise."""
        model, example_inputs = _build_kernel_context_model_and_inputs(self.device)

        with config.patch({"cpp.enable_kernel_profile": True}):
            torch.compile(model)(*example_inputs)

    @unittest.skipIf(sys.platform == "darwin", "Different kernel names on MacOS")
    def test_aoti_python_stack_traces(self):
        """AOTI output contains kernel-context guards only when profiling is enabled."""
        model, example_inputs = _build_kernel_context_model_and_inputs(self.device)

        ep = torch.export.export(model, example_inputs)
        _, code = run_and_get_cpp_code(torch._inductor.aoti_compile_and_package, ep)

        # Without the flag, no guard may be emitted.
        self.assertNotIn("KernelContextGuard", code)

        with config.patch(
            {
                "trace.provenance_tracking_level": 1,
                "cpp.enable_kernel_profile": True,
            }
        ):
            package_path, code = run_and_get_cpp_code(
                torch._inductor.aoti_compile_and_package, ep
            )

        FileCheck().check(
            "#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
        ).check("thread_local KernelContext* tls_kernel_context = nullptr;").run(code)

        # (kernel name passed to the guard, call expected inside the guarded block),
        # listed in the order they must appear in the generated code.
        if self.device == "cuda":
            expected = [
                (
                    "aoti_torch_cuda_mm_out",
                    "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(",
                ),
                (
                    "triton_poi_fused_addmm_relu_sigmoid_0",
                    "call_triton_poi_fused_addmm_relu_sigmoid_0(",
                ),
                ("triton_poi_fused_mul_1", "call_triton_poi_fused_mul_1("),
                (
                    "aoti_torch_cuda_mm_out",
                    "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(",
                ),
                (
                    "triton_poi_fused_addmm_gelu_2",
                    "call_triton_poi_fused_addmm_gelu_2(",
                ),
            ]
        else:
            expected = [
                (
                    "aoti_torch_cpu_addmm_out",
                    "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(",
                ),
                ("cpp_fused_mul_relu_sigmoid_0", "cpp_fused_mul_relu_sigmoid_0("),
                (
                    "aoti_torch_cpu_addmm_out",
                    "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(",
                ),
                ("cpp_fused_gelu_1", "cpp_fused_gelu_1("),
            ]

        checker = FileCheck()
        for guarded_kernel, call_snippet in expected:
            checker = checker.check(
                f'KernelContextGuard _ctx("{guarded_kernel}", R"('
            ).check(call_snippet)
        checker.run(code)

        # The package built with guards must still compute the same result.
        compiled_model = torch._inductor.aoti_load_package(package_path)
        result = compiled_model(*example_inputs)
        self.assertEqual(result, model(*example_inputs))
|
||||||
|
|
||||||
|
|
||||||
|
class TestProvenanceTracingKernelContextCpu(TestCase):
    """CPU variant; tests come from ProvenanceTracingKernelContextTemplate via copy_tests."""

    device = "cpu"


copy_tests(
    ProvenanceTracingKernelContextTemplate, TestProvenanceTracingKernelContextCpu, "cpu"
)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
@unittest.skipIf(not torch.cuda.is_available(), "No CUDA")
class TestProvenanceTracingKernelContextGpu(TestCase):
    """CUDA variant; tests come from ProvenanceTracingKernelContextTemplate via copy_tests."""

    device = "cuda"


copy_tests(
    ProvenanceTracingKernelContextTemplate, TestProvenanceTracingKernelContextGpu, "cuda"
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ class TestUtils(TestCase):
|
||||||
|
|
||||||
def test_flops_fx(self):
|
def test_flops_fx(self):
|
||||||
def create_fx_node(
|
def create_fx_node(
|
||||||
aten: torch._ops.OpOverloadPacket, args, kwargs
|
aten, op_overload: torch._ops.OpOverload, args, kwargs
|
||||||
) -> tuple[torch.fx.Node, torch.fx.Node]:
|
) -> tuple[torch.fx.Node, torch.fx.Node]:
|
||||||
node1 = torch.fx.Node(
|
node1 = torch.fx.Node(
|
||||||
graph=torch.fx.Graph(),
|
graph=torch.fx.Graph(),
|
||||||
|
|
@ -101,8 +101,13 @@ class TestUtils(TestCase):
|
||||||
args=args,
|
args=args,
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
name: str = aten.overloads()[0]
|
# name: str = aten.overloads()[0]
|
||||||
op_overload: torch._ops.OpOverload = getattr(aten, name)
|
# if aten == torch.ops.aten.addmm:
|
||||||
|
# name = "default"
|
||||||
|
# print(aten)
|
||||||
|
# print(aten.overloads())
|
||||||
|
# print(name)
|
||||||
|
# op_overload: torch._ops.OpOverload = getattr(aten, name)
|
||||||
node2 = torch.fx.Node(
|
node2 = torch.fx.Node(
|
||||||
graph=torch.fx.Graph(),
|
graph=torch.fx.Graph(),
|
||||||
name="",
|
name="",
|
||||||
|
|
@ -119,17 +124,25 @@ class TestUtils(TestCase):
|
||||||
trues = [
|
trues = [
|
||||||
(
|
(
|
||||||
torch.ops.aten.addmm,
|
torch.ops.aten.addmm,
|
||||||
|
torch.ops.aten.addmm.default,
|
||||||
(torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
|
(torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
|
||||||
{},
|
{},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
torch.ops.aten.bmm,
|
torch.ops.aten.bmm,
|
||||||
|
torch.ops.aten.bmm.default,
|
||||||
(torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
|
(torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
|
||||||
{},
|
{},
|
||||||
),
|
),
|
||||||
(torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
|
(
|
||||||
|
torch.ops.aten.mm,
|
||||||
|
torch.ops.aten.mm.default,
|
||||||
|
(torch.Tensor(2, 3), torch.Tensor(3, 2)),
|
||||||
|
{},
|
||||||
|
),
|
||||||
(
|
(
|
||||||
torch.ops.aten.convolution,
|
torch.ops.aten.convolution,
|
||||||
|
torch.ops.aten.convolution.default,
|
||||||
(
|
(
|
||||||
torch.Tensor(2, 2, 3),
|
torch.Tensor(2, 2, 3),
|
||||||
torch.Tensor(2, 2, 2),
|
torch.Tensor(2, 2, 2),
|
||||||
|
|
@ -145,6 +158,7 @@ class TestUtils(TestCase):
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
torch.ops.aten._convolution,
|
torch.ops.aten._convolution,
|
||||||
|
torch.ops.aten._convolution.deprecated,
|
||||||
(
|
(
|
||||||
torch.Tensor(2, 2, 2),
|
torch.Tensor(2, 2, 2),
|
||||||
torch.Tensor(2, 2, 2),
|
torch.Tensor(2, 2, 2),
|
||||||
|
|
@ -166,17 +180,19 @@ class TestUtils(TestCase):
|
||||||
falses = [
|
falses = [
|
||||||
(
|
(
|
||||||
torch.ops.aten.add,
|
torch.ops.aten.add,
|
||||||
|
torch.ops.aten.add.Tensor,
|
||||||
(torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
|
(torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
|
||||||
{},
|
{},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
torch.ops.aten.mul,
|
torch.ops.aten.mul,
|
||||||
|
torch.ops.aten.mul.Tensor,
|
||||||
(torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
|
(torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
|
||||||
{},
|
{},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
for t, args, kwargs in trues:
|
for t, t2, args, kwargs in trues:
|
||||||
fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
|
fx_node_1, fx_node_2 = create_fx_node(t, t2, args, kwargs)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
|
countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
|
||||||
)
|
)
|
||||||
|
|
@ -185,8 +201,8 @@ class TestUtils(TestCase):
|
||||||
)
|
)
|
||||||
self.assertNotEqual(count_flops_fx(fx_node_1), None)
|
self.assertNotEqual(count_flops_fx(fx_node_1), None)
|
||||||
self.assertNotEqual(count_flops_fx(fx_node_2), None)
|
self.assertNotEqual(count_flops_fx(fx_node_2), None)
|
||||||
for f, args, kwargs in falses:
|
for f, f2, args, kwargs in falses:
|
||||||
fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
|
fx_node_1, fx_node_2 = create_fx_node(f, f2, args, kwargs)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
|
countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5469,7 +5469,16 @@ class CppScheduling(BaseScheduling):
|
||||||
src_code, self.kernel_group.scheduled_nodes
|
src_code, self.kernel_group.scheduled_nodes
|
||||||
)
|
)
|
||||||
self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
|
self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
|
||||||
|
if config.cpp.enable_kernel_profile:
|
||||||
|
V.graph.wrapper_code.write_kernel_context_guard_begin()
|
||||||
|
V.graph.wrapper_code.write_kernel_context_guard(
|
||||||
|
kernel_name,
|
||||||
|
self.kernel_group.scheduled_nodes, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
|
self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
|
||||||
|
if config.cpp.enable_kernel_profile:
|
||||||
|
V.graph.wrapper_code.write_kernel_context_guard_end()
|
||||||
|
|
||||||
self.reset_kernel_group()
|
self.reset_kernel_group()
|
||||||
self._set_flush_status(False)
|
self._set_flush_status(False)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from torch.utils._ordered_set import OrderedSet
|
||||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||||
|
|
||||||
from .. import config, cpp_builder, ir
|
from .. import config, cpp_builder, ir
|
||||||
|
from ..ir import ExternKernel
|
||||||
from ..utils import _align, DeferredLineBase, LineContext, normalize_name
|
from ..utils import _align, DeferredLineBase, LineContext, normalize_name
|
||||||
from ..virtualized import V
|
from ..virtualized import V
|
||||||
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
||||||
|
|
@ -43,6 +44,8 @@ if TYPE_CHECKING:
|
||||||
# At most, the list nesting can go one layer deep.
|
# At most, the list nesting can go one layer deep.
|
||||||
_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
|
_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
|
||||||
|
|
||||||
|
from ..scheduler import BaseSchedulerNode
|
||||||
|
|
||||||
|
|
||||||
class HasWriteLine(Protocol):
|
class HasWriteLine(Protocol):
|
||||||
def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
|
def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
|
||||||
|
|
@ -233,6 +236,18 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||||
self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
|
self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
|
||||||
self.header.splice("\n")
|
self.header.splice("\n")
|
||||||
|
|
||||||
|
if config.cpp.enable_kernel_profile:
|
||||||
|
self.header.splice(
|
||||||
|
"#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
|
||||||
|
)
|
||||||
|
self.header.splice(
|
||||||
|
"""
|
||||||
|
namespace torch::aot_inductor {
|
||||||
|
thread_local KernelContext* tls_kernel_context = nullptr;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
def _include_extra_header(self, header: str):
|
def _include_extra_header(self, header: str):
|
||||||
# This is needed for cpp to python dtype conversion
|
# This is needed for cpp to python dtype conversion
|
||||||
self.header.splice(f"#include <{header}>")
|
self.header.splice(f"#include <{header}>")
|
||||||
|
|
@ -1249,7 +1264,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||||
device: str,
|
device: str,
|
||||||
*,
|
*,
|
||||||
debug_args: Optional[list[str]] = None,
|
debug_args: Optional[list[str]] = None,
|
||||||
debug_handle: Optional[int] = None,
|
stack_traces: Optional[OrderedSet[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in
|
"""debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in
|
||||||
place of args while preserving debug printer output."""
|
place of args while preserving debug printer output."""
|
||||||
|
|
@ -1266,21 +1281,26 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||||
]
|
]
|
||||||
with debug_printer_manager:
|
with debug_printer_manager:
|
||||||
shim_fn = self.get_c_shim_func_name(kernel, device)
|
shim_fn = self.get_c_shim_func_name(kernel, device)
|
||||||
self.write_provenance_debug_handle(shim_fn, debug_handle)
|
shim_fn_codes = [
|
||||||
shim_fn_codes = (
|
|
||||||
f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
|
f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
|
||||||
)
|
]
|
||||||
if enable_kernel_profile:
|
if enable_kernel_profile:
|
||||||
debug_handle_str = "" if debug_handle is None else f":{debug_handle}"
|
stack_trace_str = 'R"('
|
||||||
shim_fn_codes = textwrap.dedent(
|
if stack_traces:
|
||||||
f"""
|
for stack_trace in stack_traces:
|
||||||
{{
|
for line in stack_trace.split("\n"):
|
||||||
RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr);
|
stack_trace_str += f"\n{line}"
|
||||||
{shim_fn_codes}
|
stack_trace_str += "\n"
|
||||||
}}
|
stack_trace_str += ')"'
|
||||||
"""
|
|
||||||
)
|
shim_fn_codes = [
|
||||||
self.writeline(shim_fn_codes)
|
"{",
|
||||||
|
f"""KernelContextGuard _ctx("{shim_fn}", {stack_trace_str});""",
|
||||||
|
f"""RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);""",
|
||||||
|
shim_fn_codes[0],
|
||||||
|
"}",
|
||||||
|
]
|
||||||
|
self.writelines(shim_fn_codes)
|
||||||
|
|
||||||
def generate_c_shim_extern_kernel_alloc(
|
def generate_c_shim_extern_kernel_alloc(
|
||||||
self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
|
self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
|
||||||
|
|
@ -1373,7 +1393,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||||
out_view: Optional[str],
|
out_view: Optional[str],
|
||||||
args: list[str],
|
args: list[str],
|
||||||
device: str,
|
device: str,
|
||||||
debug_handle: Optional[int] = None,
|
stack_traces: Optional[OrderedSet[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if out_view:
|
if out_view:
|
||||||
out_name = f"{out}_as_strided"
|
out_name = f"{out}_as_strided"
|
||||||
|
|
@ -1383,7 +1403,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||||
args.insert(0, out)
|
args.insert(0, out)
|
||||||
|
|
||||||
self.generate_c_shim_extern_kernel_call(
|
self.generate_c_shim_extern_kernel_call(
|
||||||
kernel, args, device, debug_handle=debug_handle
|
kernel, args, device, stack_traces=stack_traces
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_scatter_reduce_enum(self, reduce):
|
def _get_scatter_reduce_enum(self, reduce):
|
||||||
|
|
@ -2897,3 +2917,53 @@ if (!custom_op_wrapper) {
|
||||||
writer.writeline(call_str)
|
writer.writeline(call_str)
|
||||||
|
|
||||||
return tmp_var_name
|
return tmp_var_name
|
||||||
|
|
||||||
|
def write_kernel_context_guard_begin(self):
    """Open a kernel-context-guarded block in the generated wrapper code.

    The complete block has this shape::

        {
          KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});
          ... operations...
        }

    Only the opening brace is written here.
    """
    block_open = "{"
    self.writeline(block_open)
|
||||||
|
|
||||||
|
def write_kernel_context_guard_end(self):
    """Close a kernel-context-guarded block by writing its closing brace."""
    block_close = "}"
    self.writeline(block_close)
|
||||||
|
|
||||||
|
def write_kernel_context_guard(
|
||||||
|
self,
|
||||||
|
kernel_name: str,
|
||||||
|
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
||||||
|
):
|
||||||
|
def aggregate_stack_traces(
|
||||||
|
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
||||||
|
) -> OrderedSet[str]:
|
||||||
|
if isinstance(node_schedule, list):
|
||||||
|
return functools.reduce(
|
||||||
|
lambda a, b: a | b,
|
||||||
|
[
|
||||||
|
node.node.get_stack_traces()
|
||||||
|
for node in node_schedule
|
||||||
|
if hasattr(node, "node") and node.node
|
||||||
|
],
|
||||||
|
OrderedSet(),
|
||||||
|
)
|
||||||
|
elif isinstance(node_schedule, ExternKernel):
|
||||||
|
return node_schedule.get_stack_traces()
|
||||||
|
else:
|
||||||
|
return OrderedSet()
|
||||||
|
|
||||||
|
stack_trace_str = 'R"('
|
||||||
|
stack_traces = aggregate_stack_traces(node_schedule)
|
||||||
|
|
||||||
|
for stack_trace in stack_traces:
|
||||||
|
for line in stack_trace.split("\n"):
|
||||||
|
stack_trace_str += f"\n{line}"
|
||||||
|
stack_trace_str += "\n"
|
||||||
|
stack_trace_str += ')"'
|
||||||
|
self.writeline(f'KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});')
|
||||||
|
|
|
||||||
|
|
@ -1707,6 +1707,9 @@ class SIMDScheduling(BaseScheduling):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
|
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
|
||||||
|
"""
|
||||||
|
Generate code for nodes in kernel_features
|
||||||
|
"""
|
||||||
node_schedule = kernel_features.node_schedule
|
node_schedule = kernel_features.node_schedule
|
||||||
|
|
||||||
tiling, tiling_score = self.get_tiling_and_scores(
|
tiling, tiling_score = self.get_tiling_and_scores(
|
||||||
|
|
@ -1747,7 +1750,15 @@ class SIMDScheduling(BaseScheduling):
|
||||||
node for node in node_schedule if isinstance(node, BaseSchedulerNode)
|
node for node in node_schedule if isinstance(node, BaseSchedulerNode)
|
||||||
]
|
]
|
||||||
self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name)
|
self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name)
|
||||||
|
if config.cpp.enable_kernel_profile:
|
||||||
|
V.graph.wrapper_code.write_kernel_context_guard_begin()
|
||||||
|
V.graph.wrapper_code.write_kernel_context_guard(
|
||||||
|
final_kernel.kernel_name,
|
||||||
|
base_scheduler_nodes, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
final_kernel.call_kernel(final_kernel.kernel_name)
|
final_kernel.call_kernel(final_kernel.kernel_name)
|
||||||
|
if config.cpp.enable_kernel_profile:
|
||||||
|
V.graph.wrapper_code.write_kernel_context_guard_end()
|
||||||
|
|
||||||
if config.nan_asserts:
|
if config.nan_asserts:
|
||||||
final_kernel.codegen_nan_check()
|
final_kernel.codegen_nan_check()
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,8 @@ if TYPE_CHECKING:
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
from ..graph import GraphLowering
|
from ..graph import GraphLowering
|
||||||
|
from ..ir import ExternKernel
|
||||||
|
from ..scheduler import BaseSchedulerNode
|
||||||
from .wrapper_fxir import FxConverter
|
from .wrapper_fxir import FxConverter
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -528,6 +530,7 @@ class ExternKernelOutLine(WrapperLine):
|
||||||
node.output_view.codegen_reference() if node.output_view else None,
|
node.output_view.codegen_reference() if node.output_view else None,
|
||||||
args,
|
args,
|
||||||
device,
|
device,
|
||||||
|
self.node.get_stack_traces(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||||
|
|
@ -1554,6 +1557,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
out_view: Optional[str],
|
out_view: Optional[str],
|
||||||
args: list[str],
|
args: list[str],
|
||||||
device: str,
|
device: str,
|
||||||
|
stack_traces: Optional[OrderedSet[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# add debug printer code for triton kernel calls at (jit) inductor level
|
# add debug printer code for triton kernel calls at (jit) inductor level
|
||||||
debug_printer_manager = V.graph.wrapper_code.debug_printer
|
debug_printer_manager = V.graph.wrapper_code.debug_printer
|
||||||
|
|
@ -3690,6 +3694,29 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
def can_prove_buffer_has_static_shape(buffer):
|
def can_prove_buffer_has_static_shape(buffer):
|
||||||
return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None
|
return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None
|
||||||
|
|
||||||
|
def write_kernel_context_guard(
    self,
    kernel_name: str,
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
):
    """
    No-op in this wrapper: no kernel context guard is written.
    """
    return None
|
||||||
|
|
||||||
|
def write_kernel_context_guard_begin(self):
    """
    Hook marking where a kernel context guard block begins; writes nothing here.
    """
    return None
|
||||||
|
|
||||||
|
def write_kernel_context_guard_end(self):
    """
    Hook marking where a kernel context guard block ends; writes nothing here.
    """
    return None
|
||||||
|
|
||||||
|
|
||||||
class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
|
class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
51
torch/csrc/inductor/aoti_runtime/kernel_context_tls.h
Normal file
51
torch/csrc/inductor/aoti_runtime/kernel_context_tls.h
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace torch::aot_inductor {

// Describes the kernel currently running on a thread: its name and the Python
// stack trace(s) of the source that produced it.
struct KernelContext {
  std::string kernel_name;
  std::string python_stack;

  KernelContext(std::string name, std::string stack)
      : kernel_name(std::move(name)), python_stack(std::move(stack)) {}
};

// Thread-local pointer to the active kernel context (nullptr when none).
// Only declared here; the definition is emitted into the generated wrapper
// code when kernel profiling is enabled.
extern thread_local KernelContext* tls_kernel_context;

// Returns the context installed on the calling thread, or nullptr.
inline KernelContext* current_kernel_context() {
  return tls_kernel_context;
}

// Installs `ctx` on the calling thread. The pointer is not owned; the caller
// must keep `*ctx` alive for as long as it stays installed.
inline void set_kernel_context(KernelContext* ctx) {
  tls_kernel_context = ctx;
}

// Removes any installed context from the calling thread.
inline void clear_kernel_context() {
  tls_kernel_context = nullptr;
}

// Scope guard that owns a KernelContext and installs it on the calling thread
// for the guard's lifetime.
//
// On destruction the context that was installed before this guard is restored,
// so guards may be nested; with no enclosing guard this resets the pointer to
// nullptr. Guards are expected to be destroyed in reverse order of
// construction, on the thread that created them.
struct KernelContextGuard {
  KernelContextGuard(const std::string& name, const std::string& stack)
      : owned_context_(name, stack),
        previous_context_(current_kernel_context()) {
    set_kernel_context(&owned_context_);
  }

  ~KernelContextGuard() {
    // A moved-from guard no longer represents an installed context.
    if (active_) {
      set_kernel_context(previous_context_);
    }
  }

  // Delete copy constructor and copy assignment operator
  KernelContextGuard(const KernelContextGuard&) = delete;
  KernelContextGuard& operator=(const KernelContextGuard&) = delete;

  // Moving transfers responsibility for the installed context. The
  // thread-local pointer is re-aimed at the new guard's storage, because the
  // moved-from object's context is about to become stale.
  KernelContextGuard(KernelContextGuard&& other) noexcept
      : owned_context_(std::move(other.owned_context_)),
        previous_context_(other.previous_context_),
        active_(other.active_) {
    other.active_ = false;
    if (active_ && current_kernel_context() == &other.owned_context_) {
      set_kernel_context(&owned_context_);
    }
  }
  KernelContextGuard& operator=(KernelContextGuard&&) = delete;

 private:
  KernelContext owned_context_;
  // Context that was installed when this guard was constructed.
  KernelContext* previous_context_;
  // False once this guard has been moved from.
  bool active_{true};
};

} // namespace torch::aot_inductor
|
||||||
Loading…
Reference in New Issue
Block a user