Summary:
We add a thread_local `KernelContext` object so that Strobelight (and other profilers) can read the stack trace of the kernel that is currently running.
Since this adds overhead, it is guarded behind the `cpp.enable_kernel_profile` flag.
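
For reference, here is a minimal sketch of what the TLS slot and its RAII guard could look like. The real definitions live in `torch/csrc/inductor/aoti_runtime/kernel_context_tls.h`; the field and member names below are illustrative assumptions, not the actual API:
```cpp
// Hypothetical sketch of the TLS kernel context and its RAII guard.
// The actual definitions are in
// torch/csrc/inductor/aoti_runtime/kernel_context_tls.h; names here
// are assumptions for illustration.
namespace torch::aot_inductor {

struct KernelContext {
  const char* kernel_name;  // e.g. "aoti_torch_cpu_addmm_out"
  const char* stack_trace;  // Python source stack captured at compile time
};

// Declared here; defined once (thread_local) in the generated model code.
extern thread_local KernelContext* tls_kernel_context;

// Installs a KernelContext for the duration of one kernel call, so a
// sampling profiler reading tls_kernel_context on this thread sees the
// currently running kernel. Restores the previous context on scope exit.
class KernelContextGuard {
 public:
  KernelContextGuard(const char* kernel_name, const char* stack_trace)
      : ctx_{kernel_name, stack_trace}, prev_(tls_kernel_context) {
    tls_kernel_context = &ctx_;
  }
  ~KernelContextGuard() {
    tls_kernel_context = prev_;
  }
  KernelContextGuard(const KernelContextGuard&) = delete;
  KernelContextGuard& operator=(const KernelContextGuard&) = delete;

 private:
  KernelContext ctx_;
  KernelContext* prev_;
};

} // namespace torch::aot_inductor
```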
Example output code:
```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                       // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {
    __check_inputs_outputs(input_handles, output_handles);
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
    auto arg2_1 = std::move(inputs[0]);
    auto arg3_1 = std::move(inputs[1]);
    auto arg4_1 = std::move(inputs[2]);
    auto arg5_1 = std::move(inputs[3]);
    [[maybe_unused]] auto& fc1_weight = constants_->at(0);
    [[maybe_unused]] auto& fc1_bias = constants_->at(1);
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {8L, 16L};
    static constexpr int64_t int_array_1[] = {16L, 1L};
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    // Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:1
    static constexpr int64_t int_array_2[] = {10L, 16L};
    static constexpr int64_t int_array_3[] = {1L, 10L};
    {
        KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
x = self.fc1(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
)");
        RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
    }
    arg2_1.reset();
    auto buf1 = std::move(buf0); // reuse
    static constexpr int64_t int_array_4[] = {10L, 20L};
    static constexpr int64_t int_array_5[] = {20L, 1L};
    AtenTensorHandle buf2_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
    RAIIAtenTensorHandle buf2(buf2_handle);
    // [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
    {
        KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
x = self.sigmoid(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
x = self.relu(x)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
d = a * 3.14
)");
        cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
    }
    arg3_1.reset();
    static constexpr int64_t int_array_6[] = {10L, 30L};
    static constexpr int64_t int_array_7[] = {30L, 1L};
    AtenTensorHandle buf3_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
    RAIIAtenTensorHandle buf3(buf3_handle);
    // Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:3
    {
        KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
y = torch.addmm(c, d, b)
)");
        RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
    }
    arg4_1.reset();
    arg5_1.reset();
    buf2.reset();
    auto buf4 = std::move(buf3); // reuse
    // [Provenance debug handles] cpp_fused_gelu_1:4
    {
        KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
z = torch.nn.functional.gelu(y)
)");
        cpp_fused_gelu_1((float*)(buf4.data_ptr()));
    }
    output_handles[0] = buf1.release();
    output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```
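
On the read side, a profiler sampling this thread can dereference `tls_kernel_context` to attribute the sample to the running kernel. A minimal sketch, assuming the hypothetical `KernelContext` fields from the sketch above are in scope:
```cpp
// Hypothetical consumer-side read, assuming the KernelContext sketch
// above is visible. A real integration (e.g. Strobelight) would run
// this from its sampling hook rather than printing.
#include <cstdio>

void on_profiler_sample() {
  using torch::aot_inductor::tls_kernel_context;
  if (tls_kernel_context != nullptr) {
    // Attribute the sample to the kernel currently installed in TLS.
    std::printf("kernel: %s\n", tls_kernel_context->kernel_name);
    std::printf("stack:%s\n", tls_kernel_context->stack_trace);
  }
}
```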
Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r stack_traces
```
Differential Revision: D78436007
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
# Owner(s): ["module: inductor"]

import unittest

from sympy import Symbol, sympify

import torch
from torch._inductor.fx_utils import count_flops_fx, countable_fx
from torch._inductor.utils import get_device_tflops, sympy_str, sympy_subs
from torch._inductor.virtualized import V
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestUtils(TestCase):
    def test_zip_schema(self):
        def foo(x: torch.Tensor) -> None:
            pass

        result = torch.library.custom_op("mylib::foo", foo, mutates_args={"x"})
        schema = result._opoverload._schema
        g = torch.tensor([11, 2])

        # The argument should be found whether it is passed as a kwarg ...
        found = False
        for arg, val in torch._library.utils.zip_schema(schema, [], {"x": g}):
            if arg.name == "x":
                found = True
        self.assertTrue(found)

        # ... or positionally.
        found = False
        for arg, val in torch._library.utils.zip_schema(schema, [g], {}):
            if arg.name == "x":
                found = True
        self.assertTrue(found)

    def test_sympy_subs(self):
        # integer and nonnegative attributes are preserved.
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # invalid replacement: assumptions don't match.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # valid replacement since properties match.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # invalid replacement: assumptions don't match.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # the replaced key can't be a string
        self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"})

        # the replaced key can be an expression
        expr = Symbol("x")
        expr = abs(expr)
        self.assertEqual(expr.is_integer, None)
        self.assertEqual(expr.is_nonnegative, None)
        # replace abs(x) with y and
        # propagate abs(x) sympy properties.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

    def test_sympy_str(self):
        self.assertEqual(sympy_str(sympify("a+b+c")), "a + b + c")
        self.assertEqual(sympy_str(sympify("a*b+c")), "c + a * b")
        self.assertEqual(sympy_str(sympify("a+b*(c+d)")), "a + b * (c + d)")
        self.assertEqual(sympy_str(sympify("(a+b)*(c+d)")), "(a + b) * (c + d)")
        self.assertEqual(sympy_str(sympify("-a")), "-a")
        self.assertEqual(sympy_str(sympify("a-b")), "a - b")
        self.assertEqual(sympy_str(sympify("a+-b")), "a - b")

    def test_flops_fx(self):
        def create_fx_node(
            aten, op_overload: torch._ops.OpOverload, args, kwargs
        ) -> tuple[torch.fx.Node, torch.fx.Node]:
            # Build one node targeting the OpOverloadPacket and one
            # targeting the resolved OpOverload, so both forms are tested.
            node1 = torch.fx.Node(
                graph=torch.fx.Graph(),
                name="",
                op="call_function",
                target=aten,
                args=args,
                kwargs=kwargs,
            )
            node2 = torch.fx.Node(
                graph=torch.fx.Graph(),
                name="",
                op="call_function",
                target=op_overload,
                args=args,
                kwargs=kwargs,
            )
            return node1, node2

        with V.set_fake_mode(
            torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        ):
            trues = [
                (
                    torch.ops.aten.addmm,
                    torch.ops.aten.addmm.default,
                    (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
                    {},
                ),
                (
                    torch.ops.aten.bmm,
                    torch.ops.aten.bmm.default,
                    (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
                    {},
                ),
                (
                    torch.ops.aten.mm,
                    torch.ops.aten.mm.default,
                    (torch.Tensor(2, 3), torch.Tensor(3, 2)),
                    {},
                ),
                (
                    torch.ops.aten.convolution,
                    torch.ops.aten.convolution.default,
                    (
                        torch.Tensor(2, 2, 3),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1,),
                        (0,),
                        (1,),
                        True,
                        (0,),
                        1,
                    ),
                    {},
                ),
                (
                    torch.ops.aten._convolution,
                    torch.ops.aten._convolution.deprecated,
                    (
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1,),
                        (0,),
                        (1,),
                        True,
                        (0,),
                        1,
                        False,
                        True,
                        False,
                    ),
                    {},
                ),
            ]
            # we don't support pointwise ops
            falses = [
                (
                    torch.ops.aten.add,
                    torch.ops.aten.add.Tensor,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
                (
                    torch.ops.aten.mul,
                    torch.ops.aten.mul.Tensor,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
            ]
            for t, t2, args, kwargs in trues:
                fx_node_1, fx_node_2 = create_fx_node(t, t2, args, kwargs)
                self.assertTrue(
                    countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
                )
                self.assertTrue(
                    countable_fx(fx_node_2), f"Expected true {t}: {fx_node_2}"
                )
                self.assertNotEqual(count_flops_fx(fx_node_1), None)
                self.assertNotEqual(count_flops_fx(fx_node_2), None)
            for f, f2, args, kwargs in falses:
                fx_node_1, fx_node_2 = create_fx_node(f, f2, args, kwargs)
                self.assertFalse(
                    countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
                )
                self.assertFalse(
                    countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
                )

    @unittest.skipIf(not torch.cuda.is_available(), "skip if no device")
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    def test_get_device_tflops(self, dtype):
        ret = get_device_tflops(dtype)
        self.assertTrue(type(ret) is float)


instantiate_device_type_tests(TestUtils, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()