[easy] Handle Autotuners in get_triton_source_codes_for_gm (#161914)
Some Triton kernels are autotuners; in that case, grab the underlying function from the autotuner.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161914
Approved by: https://github.com/oulgen
parent 7d1bcd9aea
commit 70337a066f
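For context: a user-defined kernel decorated with @triton.autotune is an Autotuner object wrapping the underlying JITFunction, and the inner function is exposed as .fn. A minimal sketch of that relationship (the kernel and its configs are illustrative only, not part of this PR):

    import triton
    import triton.language as tl
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    # Hypothetical autotuned kernel, for illustration only.
    @triton.autotune(
        configs=[triton.Config({"BLOCK": 64}), triton.Config({"BLOCK": 128})],
        key=["n"],
    )
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    # add_kernel is an Autotuner, not a JITFunction, so source extraction
    # must first unwrap it via .fn to reach the inner JITFunction.
    assert isinstance(add_kernel, Autotuner)
    assert isinstance(add_kernel.fn, JITFunction)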
@@ -306,6 +306,8 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
         self,
         gm: torch.fx.GraphModule,
     ):
+        assert has_triton_package(), "Triton is not available"
+
         triton_kernels = []
         for module in gm.modules():
             if not isinstance(module, torch.fx.GraphModule):
@@ -331,6 +333,11 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
         )

         for kernel in triton_kernels:
+            from triton.runtime.autotuner import Autotuner
+
+            if isinstance(kernel, Autotuner):
+                # Grab the Inner JITFunction
+                kernel = kernel.fn
             source_codes = user_defined_triton_kernel_transitive_closure_source_code(
                 kernel
             )
@@ -355,7 +362,8 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
             [],
             [],
         )
-        self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm)
+        if has_triton_package():
+            self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm)

         if hasattr(gm, "saved_tensors_hooks_pack_0"):
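Taken together, the change amounts to the following pattern. A simplified sketch, not the actual AOTAutogradCacheDetails code: collect_triton_kernels is a hypothetical stand-in for the module walk elided above, and has_triton_package is PyTorch's existing helper in torch.utils._triton:

    from torch.utils._triton import has_triton_package

    def triton_kernel_sources_for_cache_key(gm):
        # Skip entirely when Triton is not installed, so hashing a graph
        # module never requires importing triton.
        if not has_triton_package():
            return []
        from triton.runtime.autotuner import Autotuner

        sources = []
        for kernel in collect_triton_kernels(gm):  # hypothetical helper
            if isinstance(kernel, Autotuner):
                kernel = kernel.fn  # unwrap to the inner JITFunction
            sources.append(kernel.src)  # JITFunction keeps its source text in .src
        return sources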