pytorch/test/inductor/test_utils.py
Shangdi Yu d2eff5d454 Add python stack trace to AOTI generated code (#160539)
Summary:
We add a thread_local KernelContext object so Strobelight (and other potential profilers) can read the stack trace information of the running kernel.

This will bring extra overhead, so we guard this behind the `cpp.enable_kernel_profile` flag.

Example output code:

```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {
    __check_inputs_outputs(input_handles, output_handles);
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
    auto arg2_1 = std::move(inputs[0]);
    auto arg3_1 = std::move(inputs[1]);
    auto arg4_1 = std::move(inputs[2]);
    auto arg5_1 = std::move(inputs[3]);
    [[maybe_unused]] auto& fc1_weight = constants_->at(0);
    [[maybe_unused]] auto& fc1_bias = constants_->at(1);
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {8L, 16L};
    static constexpr int64_t int_array_1[] = {16L, 1L};
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    // Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:1
    static constexpr int64_t int_array_2[] = {10L, 16L};
    static constexpr int64_t int_array_3[] = {1L, 10L};
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
    x = self.fc1(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
    }
    arg2_1.reset();
    auto buf1 = std::move(buf0);  // reuse
    static constexpr int64_t int_array_4[] = {10L, 20L};
    static constexpr int64_t int_array_5[] = {20L, 1L};
    AtenTensorHandle buf2_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
    RAIIAtenTensorHandle buf2(buf2_handle);
    // [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
    {
    KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
    x = self.sigmoid(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
    return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
    x = self.relu(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
    return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
    d = a * 3.14
)");
    cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
    }
    arg3_1.reset();
    static constexpr int64_t int_array_6[] = {10L, 30L};
    static constexpr int64_t int_array_7[] = {30L, 1L};
    AtenTensorHandle buf3_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
    RAIIAtenTensorHandle buf3(buf3_handle);
    // Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:3
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
    y = torch.addmm(c, d, b)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
    }
    arg4_1.reset();
    arg5_1.reset();
    buf2.reset();
    auto buf4 = std::move(buf3);  // reuse
    // [Provenance debug handles] cpp_fused_gelu_1:4
    {
    KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
    z = torch.nn.functional.gelu(y)
)");
    cpp_fused_gelu_1((float*)(buf4.data_ptr()));
    }
    output_handles[0] = buf1.release();
    output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r  stack_traces
```

Rollback Plan:

Differential Revision: D78436007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
2025-10-29 22:47:52 +00:00

224 lines
7.9 KiB
Python

# Owner(s): ["module: inductor"]
import unittest
from sympy import Symbol, sympify
import torch
from torch._inductor.fx_utils import count_flops_fx, countable_fx
from torch._inductor.utils import get_device_tflops, sympy_str, sympy_subs
from torch._inductor.virtualized import V
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class TestUtils(TestCase):
    """Tests for assorted Inductor helpers.

    Covers ``torch._library.utils.zip_schema``, ``sympy_subs``, ``sympy_str``,
    the FX flop-counting helpers (``countable_fx`` / ``count_flops_fx``) and
    ``get_device_tflops``.
    """

    def test_zip_schema(self):
        """``zip_schema`` yields the ``x`` argument for keyword and positional calls."""

        def foo(x: torch.Tensor) -> None:
            pass

        result = torch.library.custom_op("mylib::foo", foo, mutates_args={"x"})
        schema = result._opoverload._schema
        g = torch.tensor([11, 2])

        # First pass ``x`` by keyword, then positionally; in both cases the
        # schema argument named "x" must be among the yielded pairs.
        for call_args, call_kwargs in (([], {"x": g}), ([g], {})):
            yielded_names = [
                arg.name
                for arg, _ in torch._library.utils.zip_schema(
                    schema, call_args, call_kwargs
                )
            ]
            self.assertIn("x", yielded_names)

    def testSympySubs(self):
        """``sympy_subs`` only replaces matching symbols and keeps their assumptions."""
        # integer and nonnegative attributes are preserved.
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertIsNone(result.is_integer)
        self.assertIsNone(result.is_nonnegative)

        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # invalid replacement.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # valid replacement since properties match.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # invalid replacement.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # replaced can't be string
        with self.assertRaises(AssertionError):
            sympy_subs(expr, {"x": "y"})

        # replaced can be an expression
        expr = Symbol("x")
        expr = abs(expr)
        self.assertIsNone(expr.is_integer)
        self.assertIsNone(expr.is_nonnegative)
        # replace abs(x) with y
        # propagate abs(x) sympy properties.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertIsNone(result.is_integer)
        self.assertIsNone(result.is_nonnegative)

    def test_sympy_str(self):
        """``sympy_str`` renders expressions with the expected spacing and ordering."""
        # (expression passed to sympify, expected sympy_str output)
        cases = (
            ("a+b+c", "a + b + c"),
            ("a*b+c", "c + a * b"),
            ("a+b*(c+d)", "a + b * (c + d)"),
            ("(a+b)*(c+d)", "(a + b) * (c + d)"),
            ("-a", "-a"),
            ("a-b", "a - b"),
            ("a+-b", "a - b"),
        )
        for source, expected in cases:
            with self.subTest(expr=source):
                self.assertEqual(sympy_str(sympify(source)), expected)

    def test_flops_fx(self):
        """Matmul/convolution FX nodes are flop-countable; pointwise ones are not."""

        def create_fx_node(
            aten, op_overload: torch._ops.OpOverload, args, kwargs
        ) -> tuple[torch.fx.Node, torch.fx.Node]:
            """Build two call_function nodes, one per target, sharing args/kwargs."""

            def make_node(target) -> torch.fx.Node:
                # Each node is created against its own fresh, empty graph.
                return torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="",
                    op="call_function",
                    target=target,
                    args=args,
                    kwargs=kwargs,
                )

            return make_node(aten), make_node(op_overload)

        with V.set_fake_mode(
            torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        ):
            # Each entry: (op without overload suffix, concrete overload, args, kwargs)
            trues = [
                (
                    torch.ops.aten.addmm,
                    torch.ops.aten.addmm.default,
                    (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
                    {},
                ),
                (
                    torch.ops.aten.bmm,
                    torch.ops.aten.bmm.default,
                    (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
                    {},
                ),
                (
                    torch.ops.aten.mm,
                    torch.ops.aten.mm.default,
                    (torch.Tensor(2, 3), torch.Tensor(3, 2)),
                    {},
                ),
                (
                    torch.ops.aten.convolution,
                    torch.ops.aten.convolution.default,
                    (
                        torch.Tensor(2, 2, 3),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1,),
                        (0,),
                        (1,),
                        True,
                        (0,),
                        1,
                    ),
                    {},
                ),
                (
                    torch.ops.aten._convolution,
                    torch.ops.aten._convolution.deprecated,
                    (
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1,),
                        (0,),
                        (1,),
                        True,
                        (0,),
                        1,
                        False,
                        True,
                        False,
                    ),
                    {},
                ),
            ]
            # we don't support pointwise ops
            falses = [
                (
                    torch.ops.aten.add,
                    torch.ops.aten.add.Tensor,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
                (
                    torch.ops.aten.mul,
                    torch.ops.aten.mul.Tensor,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
            ]

            for t, t2, args, kwargs in trues:
                fx_node_1, fx_node_2 = create_fx_node(t, t2, args, kwargs)
                self.assertTrue(
                    countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
                )
                self.assertTrue(
                    countable_fx(fx_node_2), f"Expected true {t}: {fx_node_2}"
                )
                self.assertIsNotNone(count_flops_fx(fx_node_1))
                self.assertIsNotNone(count_flops_fx(fx_node_2))

            for f, f2, args, kwargs in falses:
                fx_node_1, fx_node_2 = create_fx_node(f, f2, args, kwargs)
                self.assertFalse(
                    countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
                )
                self.assertFalse(
                    countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
                )

    # NOTE(review): the skip condition only looks at CUDA availability even
    # though the suite is instantiated with allow_xpu=True -- confirm intended.
    @unittest.skipIf(not torch.cuda.is_available(), "skip if no device")
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    def test_get_device_tflops(self, dtype):
        """``get_device_tflops`` returns exactly a ``float`` for each tested dtype."""
        ret = get_device_tflops(dtype)
        # Exact type check (not isinstance) so float subclasses are rejected too.
        self.assertIs(type(ret), float)
# Instantiate the device-type variants of TestUtils into this module's
# globals; XPU is explicitly opted in via allow_xpu=True.
instantiate_device_type_tests(TestUtils, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()