diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py
index cc8596d9036..0d59616bc53 100644
--- a/test/inductor/test_provenance_tracing.py
+++ b/test/inductor/test_provenance_tracing.py
@@ -7,6 +7,7 @@ import logging
 import os
 import re
 import shutil
+import sys
 import tempfile
 import unittest
 import zipfile
@@ -24,7 +25,7 @@ from torch._inductor.debug import (
 )
 from torch._inductor.fx_passes.post_grad import post_grad_passes
 from torch._inductor.test_case import run_tests, TestCase
-from torch._inductor.utils import run_and_get_code
+from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
 from torch._inductor.virtualized import V
 from torch.testing._internal.common_utils import IS_MACOS
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
@@ -32,8 +33,12 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
 
 try:
     from .test_aot_inductor_utils import AOTIRunnerUtil
+    from .test_torchinductor import copy_tests
 except ImportError:
     from test_aot_inductor_utils import AOTIRunnerUtil
+    from test_torchinductor import (
+        copy_tests,  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
+    )
 
 
 trace_log = logging.getLogger("torch.__trace")
@@ -806,5 +811,135 @@ class TestProvenanceTracingStackTraces(TestCase):
         self.assertTrue("aoti_torch_cpu_convolution" in keys)
 
 
+class ProvenanceTracingKernelContextTemplate:
+    def test_jit_inductor_with_flag(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(10, 16)
+                self.relu = torch.nn.ReLU()
+                self.sigmoid = torch.nn.Sigmoid()
+
+            def forward(self, x, a, b, c):
+                x = self.fc1(x)
+                x = self.relu(x)
+                x = self.sigmoid(x)
+                d = a * 3.14
+                y = torch.addmm(c, d, b)
+                z = torch.nn.functional.gelu(y)
+                return x, z
+
+        model = Model().to(self.device)
+        x = torch.randn(8, 10).to(self.device)
+        a = torch.randn(10, 20).to(self.device)
+        b = torch.randn(20, 30).to(self.device)
+        c = torch.randn(10, 30).to(self.device)
+        example_inputs = (x, a, b, c)
+
+        with config.patch(
+            {
+                "cpp.enable_kernel_profile": True,
+            }
+        ):
+            torch.compile(model)(*example_inputs)
+
+    @unittest.skipIf(sys.platform == "darwin", "Different kernel names on MacOS")
+    def test_aoti_python_stack_traces(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(10, 16)
+                self.relu = torch.nn.ReLU()
+                self.sigmoid = torch.nn.Sigmoid()
+
+            def forward(self, x, a, b, c):
+                x = self.fc1(x)
+                x = self.relu(x)
+                x = self.sigmoid(x)
+                d = a * 3.14
+                y = torch.addmm(c, d, b)
+                z = torch.nn.functional.gelu(y)
+                return x, z
+
+        x = torch.randn(8, 10).to(self.device)
+        a = torch.randn(10, 20).to(self.device)
+        b = torch.randn(20, 30).to(self.device)
+        c = torch.randn(10, 30).to(self.device)
+        example_inputs = (x, a, b, c)
+        model = Model().to(self.device)
+
+        ep = torch.export.export(model, example_inputs)
+        _, code = run_and_get_cpp_code(torch._inductor.aoti_compile_and_package, ep)
+
+        self.assertTrue("KernelContextGuard" not in code)
+
+        with config.patch(
+            {
+                "trace.provenance_tracking_level": 1,
+                "cpp.enable_kernel_profile": True,
+            }
+        ):
+            package_path, code = run_and_get_cpp_code(
+                torch._inductor.aoti_compile_and_package, ep
+            )
+
+        FileCheck().check(
+            "#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
+        ).check("thread_local KernelContext* tls_kernel_context = nullptr;").run(
+            code
+        )
+
+        if self.device == "cuda":
+            FileCheck().check(
+                """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"("""
).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check( + """KernelContextGuard _ctx("triton_poi_fused_addmm_relu_sigmoid_0", R"(""" + ).check("call_triton_poi_fused_addmm_relu_sigmoid_0(").check( + """KernelContextGuard _ctx("triton_poi_fused_mul_1", R"(""" + ).check("call_triton_poi_fused_mul_1(").check( + """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"(""" + ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check( + """ KernelContextGuard _ctx("triton_poi_fused_addmm_gelu_2", R"(""" + ).check("call_triton_poi_fused_addmm_gelu_2(").run(code) + else: + FileCheck().check( + """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(""" + ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check( + """KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(""" + ).check("cpp_fused_mul_relu_sigmoid_0(").check( + """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(""" + ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check( + """ KernelContextGuard _ctx("cpp_fused_gelu_1", R"(""" + ).check("cpp_fused_gelu_1(").run(code) + + compiled_model = torch._inductor.aoti_load_package(package_path) + result = compiled_model(*example_inputs) + self.assertEqual(result, model(*example_inputs)) + + +class TestProvenanceTracingKernelContextCpu(TestCase): + device = "cpu" + + +copy_tests( + ProvenanceTracingKernelContextTemplate, + TestProvenanceTracingKernelContextCpu, + "cpu", +) + + +@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") +@unittest.skipIf(not torch.cuda.is_available(), "No CUDA") +class TestProvenanceTracingKernelContextGpu(TestCase): + device = "cuda" + + +copy_tests( + ProvenanceTracingKernelContextTemplate, + TestProvenanceTracingKernelContextGpu, + "cuda", +) + + if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_utils.py b/test/inductor/test_utils.py index fa666dfc987..9516b4ee089 100644 --- a/test/inductor/test_utils.py +++ b/test/inductor/test_utils.py @@ -91,7 +91,7 @@ class TestUtils(TestCase): def test_flops_fx(self): def create_fx_node( - aten: torch._ops.OpOverloadPacket, args, kwargs + aten, op_overload: torch._ops.OpOverload, args, kwargs ) -> tuple[torch.fx.Node, torch.fx.Node]: node1 = torch.fx.Node( graph=torch.fx.Graph(), @@ -101,8 +101,13 @@ class TestUtils(TestCase): args=args, kwargs=kwargs, ) - name: str = aten.overloads()[0] - op_overload: torch._ops.OpOverload = getattr(aten, name) + # name: str = aten.overloads()[0] + # if aten == torch.ops.aten.addmm: + # name = "default" + # print(aten) + # print(aten.overloads()) + # print(name) + # op_overload: torch._ops.OpOverload = getattr(aten, name) node2 = torch.fx.Node( graph=torch.fx.Graph(), name="", @@ -119,17 +124,25 @@ class TestUtils(TestCase): trues = [ ( torch.ops.aten.addmm, + torch.ops.aten.addmm.default, (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)), {}, ), ( torch.ops.aten.bmm, + torch.ops.aten.bmm.default, (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)), {}, ), - (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}), + ( + torch.ops.aten.mm, + torch.ops.aten.mm.default, + (torch.Tensor(2, 3), torch.Tensor(3, 2)), + {}, + ), ( torch.ops.aten.convolution, + torch.ops.aten.convolution.default, ( torch.Tensor(2, 2, 3), torch.Tensor(2, 2, 2), @@ -145,6 +158,7 @@ class TestUtils(TestCase): ), ( torch.ops.aten._convolution, + torch.ops.aten._convolution.deprecated, ( torch.Tensor(2, 2, 2), torch.Tensor(2, 2, 2), @@ -166,17 +180,19 @@ class TestUtils(TestCase): falses = [ ( 
                torch.ops.aten.add,
+               torch.ops.aten.add.Tensor,
                (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                {},
            ),
            (
                torch.ops.aten.mul,
+               torch.ops.aten.mul.Tensor,
                (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                {},
            ),
        ]
-        for t, args, kwargs in trues:
-            fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
+        for t, t2, args, kwargs in trues:
+            fx_node_1, fx_node_2 = create_fx_node(t, t2, args, kwargs)
            self.assertTrue(
                countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
            )
@@ -185,8 +194,8 @@
            )
            self.assertNotEqual(count_flops_fx(fx_node_1), None)
            self.assertNotEqual(count_flops_fx(fx_node_2), None)
-        for f, args, kwargs in falses:
-            fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
+        for f, f2, args, kwargs in falses:
+            fx_node_1, fx_node_2 = create_fx_node(f, f2, args, kwargs)
            self.assertFalse(
                countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
            )
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 65f5d37d0d8..28036b2d302 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -5469,7 +5469,16 @@ class CppScheduling(BaseScheduling):
                 src_code, self.kernel_group.scheduled_nodes
             )
             self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
+            if config.cpp.enable_kernel_profile:
+                V.graph.wrapper_code.write_kernel_context_guard_begin()
+                V.graph.wrapper_code.write_kernel_context_guard(
+                    kernel_name,
+                    self.kernel_group.scheduled_nodes,  # type: ignore[arg-type]
+                )
             self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
+            if config.cpp.enable_kernel_profile:
+                V.graph.wrapper_code.write_kernel_context_guard_end()
+
             self.reset_kernel_group()
             self._set_flush_status(False)
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index e49498cce41..0c8ea746066 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -22,6 +22,7 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 
 from .. import config, cpp_builder, ir
+from ..ir import ExternKernel
 from ..utils import _align, DeferredLineBase, LineContext, normalize_name
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
@@ -43,6 +44,8 @@ if TYPE_CHECKING:
     # At most, the list nesting can go one layer deep.
     _OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
 
+    from ..scheduler import BaseSchedulerNode
+
 
 class HasWriteLine(Protocol):
     def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
@@ -233,6 +236,18 @@
         self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
         self.header.splice("\n")
 
+        if config.cpp.enable_kernel_profile:
+            self.header.splice(
+                "#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>"
+            )
+            self.header.splice(
+                """
+                namespace torch::aot_inductor {
+                thread_local KernelContext* tls_kernel_context = nullptr;
+                }
+                """
+            )
+
     def _include_extra_header(self, header: str):
         # This is needed for cpp to python dtype conversion
         self.header.splice(f"#include <{header}>")
@@ -1249,7 +1264,7 @@
         device: str,
         *,
         debug_args: Optional[list[str]] = None,
-        debug_handle: Optional[int] = None,
+        stack_traces: Optional[OrderedSet[str]] = None,
     ) -> None:
         """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments
         in place of args while preserving debug printer output."""
@@ -1266,21 +1281,26 @@
         ]
         with debug_printer_manager:
             shim_fn = self.get_c_shim_func_name(kernel, device)
-            self.write_provenance_debug_handle(shim_fn, debug_handle)
-            shim_fn_codes = (
+            shim_fn_codes = [
                 f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
-            )
+            ]
             if enable_kernel_profile:
-                debug_handle_str = "" if debug_handle is None else f":{debug_handle}"
-                shim_fn_codes = textwrap.dedent(
-                    f"""
-                    {{
-                    RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr);
-                    {shim_fn_codes}
-                    }}
-                    """
-                )
-            self.writeline(shim_fn_codes)
+                stack_trace_str = 'R"('
+                if stack_traces:
+                    for stack_trace in stack_traces:
+                        for line in stack_trace.split("\n"):
+                            stack_trace_str += f"\n{line}"
+                        stack_trace_str += "\n"
+                stack_trace_str += ')"'
+
+                shim_fn_codes = [
+                    "{",
+                    f"""KernelContextGuard _ctx("{shim_fn}", {stack_trace_str});""",
+                    f"""RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);""",
+                    shim_fn_codes[0],
+                    "}",
+                ]
+            self.writelines(shim_fn_codes)
 
     def generate_c_shim_extern_kernel_alloc(
         self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
@@ -1373,7 +1393,7 @@
         out_view: Optional[str],
         args: list[str],
         device: str,
-        debug_handle: Optional[int] = None,
+        stack_traces: Optional[OrderedSet[str]] = None,
     ) -> None:
         if out_view:
             out_name = f"{out}_as_strided"
@@ -1383,7 +1403,7 @@
             args.insert(0, out)
 
         self.generate_c_shim_extern_kernel_call(
-            kernel, args, device, debug_handle=debug_handle
+            kernel, args, device, stack_traces=stack_traces
         )
 
     def _get_scatter_reduce_enum(self, reduce):
@@ -2897,3 +2917,53 @@ if (!custom_op_wrapper) {
         writer.writeline(call_str)
 
         return tmp_var_name
+
+    def write_kernel_context_guard_begin(
+        self,
+    ):
+        # Beginning of a kernel context guarded block.
+        # The block looks like this:
+        # {
+        #     KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});
+        #     ... operations...
+        # }
+        self.writeline("{")
+
+    def write_kernel_context_guard_end(
+        self,
+    ):
+        # End of a kernel context guarded block.
+ self.writeline("}") + + def write_kernel_context_guard( + self, + kernel_name: str, + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ): + def aggregate_stack_traces( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ) -> OrderedSet[str]: + if isinstance(node_schedule, list): + return functools.reduce( + lambda a, b: a | b, + [ + node.node.get_stack_traces() + for node in node_schedule + if hasattr(node, "node") and node.node + ], + OrderedSet(), + ) + elif isinstance(node_schedule, ExternKernel): + return node_schedule.get_stack_traces() + else: + return OrderedSet() + + stack_trace_str = 'R"(' + stack_traces = aggregate_stack_traces(node_schedule) + + for stack_trace in stack_traces: + for line in stack_trace.split("\n"): + stack_trace_str += f"\n{line}" + stack_trace_str += "\n" + stack_trace_str += ')"' + self.writeline(f'KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});') diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 7e5457f78eb..3310e3facfa 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1707,6 +1707,9 @@ class SIMDScheduling(BaseScheduling): return True def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): + """ + Generate code for nodes in kernel_features + """ node_schedule = kernel_features.node_schedule tiling, tiling_score = self.get_tiling_and_scores( @@ -1747,7 +1750,15 @@ class SIMDScheduling(BaseScheduling): node for node in node_schedule if isinstance(node, BaseSchedulerNode) ] self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_begin() + V.graph.wrapper_code.write_kernel_context_guard( + final_kernel.kernel_name, + base_scheduler_nodes, # type: ignore[arg-type] + ) final_kernel.call_kernel(final_kernel.kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_end() if config.nan_asserts: final_kernel.codegen_nan_check() diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index fa5048fd726..afc782386e3 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -77,6 +77,8 @@ if TYPE_CHECKING: import triton from ..graph import GraphLowering + from ..ir import ExternKernel + from ..scheduler import BaseSchedulerNode from .wrapper_fxir import FxConverter @@ -528,6 +530,7 @@ class ExternKernelOutLine(WrapperLine): node.output_view.codegen_reference() if node.output_view else None, args, device, + self.node.get_stack_traces(), ) def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: @@ -1554,6 +1557,7 @@ class PythonWrapperCodegen(CodeGen): out_view: Optional[str], args: list[str], device: str, + stack_traces: Optional[OrderedSet[str]] = None, ) -> None: # add debug printer code for triton kernel calls at (jit) inductor level debug_printer_manager = V.graph.wrapper_code.debug_printer @@ -3690,6 +3694,29 @@ class PythonWrapperCodegen(CodeGen): def can_prove_buffer_has_static_shape(buffer): return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None + def write_kernel_context_guard( + self, + kernel_name: str, + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ): + return + + def write_kernel_context_guard_begin( + self, + ): + """ + Mark the beginning of kernel context guard + """ + return + + def write_kernel_context_guard_end( + self, + ): + """ + Mark the end of kernel 
+        """
+        return
+
 
 class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
     """
diff --git a/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h
new file mode 100644
index 00000000000..3489494d77e
--- /dev/null
+++ b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h
@@ -0,0 +1,51 @@
+#pragma once
+
+#include <string>
+#include <utility>
+
+namespace torch::aot_inductor {
+
+struct KernelContext {
+  std::string kernel_name;
+  std::string python_stack;
+
+  KernelContext(std::string name, std::string stack)
+      : kernel_name(std::move(name)), python_stack(std::move(stack)) {}
+};
+
+// Thread-local pointer
+extern thread_local KernelContext* tls_kernel_context;
+
+inline KernelContext* current_kernel_context() {
+  return tls_kernel_context;
+}
+
+inline void set_kernel_context(KernelContext* ctx) {
+  tls_kernel_context = ctx;
+}
+
+inline void clear_kernel_context() {
+  tls_kernel_context = nullptr;
+}
+
+struct KernelContextGuard {
+  KernelContextGuard(const std::string& name, const std::string& stack)
+      : owned_context_(name, stack) {
+    set_kernel_context(&owned_context_);
+  }
+  ~KernelContextGuard() {
+    clear_kernel_context();
+  }
+
+  // Delete copy constructor and copy assignment operator
+  KernelContextGuard(const KernelContextGuard&) = delete;
+  KernelContextGuard& operator=(const KernelContextGuard&) = delete;
+
+  KernelContextGuard(KernelContextGuard&&) = default;
+  KernelContextGuard& operator=(KernelContextGuard&&) = delete;
+
+ private:
+  KernelContext owned_context_;
+};
+
+} // namespace torch::aot_inductor
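
The sketch below is not part of the patch; it is a minimal, hypothetical consumer of the thread-local context declared in kernel_context_tls.h. It assumes the generated model code defines tls_kernel_context (as the header splice in cpp_wrapper_cpu.py does) and shows how an error or profiling path could read the kernel name and captured Python stack set by the enclosing KernelContextGuard. The function name report_kernel_failure and the use of std::cerr are illustrative only.

#include <iostream>
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>

// Hypothetical error path: while a KernelContextGuard is alive on this
// thread, current_kernel_context() returns the guard's owned KernelContext;
// outside any guarded block it returns nullptr.
inline void report_kernel_failure() {
  using torch::aot_inductor::current_kernel_context;
  if (const auto* ctx = current_kernel_context()) {
    std::cerr << "kernel failed: " << ctx->kernel_name << "\n"
              << "python stack:\n"
              << ctx->python_stack << "\n";
  } else {
    std::cerr << "kernel failed outside of any KernelContextGuard\n";
  }
}

// Usage mirrors what the wrapper emits around each kernel call:
// {
//   KernelContextGuard _ctx("triton_poi_fused_mul_1", R"(...captured stack...)");
//   // launch kernel; call report_kernel_failure() on error
// }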