[easy] Handle Autotuners in get_triton_source_codes_for_gm (#161914)

Some Triton kernels are wrapped in autotuners; in that case, grab the underlying function from the autotuner.
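For illustration (not part of this commit), here is a minimal sketch, assuming a recent Triton API, of the case being handled: `@triton.autotune` wraps the underlying `JITFunction` in an `Autotuner` object, and the inner function is available via the `fn` attribute. The `add_kernel` below is a made-up example kernel.

```python
# Sketch: @triton.autotune returns an Autotuner wrapping the JITFunction,
# so source-code extraction must look at kernel.fn, not the kernel itself.
import triton
import triton.language as tl
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction

@triton.autotune(
    configs=[triton.Config({"BLOCK": 64}), triton.Config({"BLOCK": 128})],
    key=["n"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

assert isinstance(add_kernel, Autotuner)
assert isinstance(add_kernel.fn, JITFunction)  # the inner function to hash
```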

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161914
Approved by: https://github.com/oulgen
commit 70337a066f
parent 7d1bcd9aea
Author: James Wu
Date: 2025-09-14 19:58:52 +00:00
Committed by: PyTorch MergeBot


@@ -306,6 +306,8 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
         self,
         gm: torch.fx.GraphModule,
     ):
+        assert has_triton_package(), "Triton is not available"
         triton_kernels = []
         for module in gm.modules():
             if not isinstance(module, torch.fx.GraphModule):
@@ -331,6 +333,11 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
         )
         for kernel in triton_kernels:
+            from triton.runtime.autotuner import Autotuner
+
+            if isinstance(kernel, Autotuner):
+                # Grab the Inner JITFunction
+                kernel = kernel.fn
             source_codes = user_defined_triton_kernel_transitive_closure_source_code(
                 kernel
             )
@@ -355,6 +362,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
             [],
             [],
         )
-        self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm)
+        if has_triton_package():
+            self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm)
         if hasattr(gm, "saved_tensors_hooks_pack_0"):
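The guard-then-unwrap pattern in the hunks above could be factored into a small helper; `_unwrap_autotuner` is a hypothetical name for this sketch, not a function introduced by the commit.

```python
# Hypothetical helper (not in the commit) mirroring the pattern above:
# unwrap an Autotuner to its inner JITFunction before collecting source.
def _unwrap_autotuner(kernel):
    try:
        from triton.runtime.autotuner import Autotuner
    except ImportError:
        return kernel  # Triton not installed; nothing to unwrap
    if isinstance(kernel, Autotuner):
        return kernel.fn  # the inner JITFunction carries the source code
    return kernel
```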