[Inductor] Properly package target info for triton.compile (#125241)

Triton updated the interface for `triton.compile` in commit 5162346487.

The `target` argument to `compile` needs to be wrapped in a `GPUTarget` object. Without that wrapping, we hit an assert in `compile`; if that assert is removed, Triton instead attempts to read device info from Torch while inside a Torch-spawned thread, which trips Torch's in-bad-fork assert. This change is required for compatibility with the latest Triton commits, and it is backwards compatible: existing Triton versions that work today continue to work.
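For reference, here is a minimal self-contained sketch of the compatibility shim (mirroring the diff below; the helper name `make_target` is hypothetical, and `device_type` / `cc` stand in for the values Inductor reads from `compile_meta`):

```python
import torch

# Newer Triton exposes GPUTarget; older Triton does not.
try:
    from triton.backends.compiler import GPUTarget
except ImportError:
    GPUTarget = None


def make_target(device_type, cc, rocm_warp_size=64):
    # New interface: wrap the device info in a GPUTarget object.
    if GPUTarget:
        return GPUTarget(
            device_type,
            cc,
            rocm_warp_size if torch.version.hip else 32,
        )
    # Old interface: a (device_type, cc) tuple for CUDA, or a list that
    # also carries the warp size for ROCm.
    if not torch.version.hip:
        return (device_type, cc)
    return [device_type, cc, rocm_warp_size]
```

Either form is then passed as the `target` keyword to `triton.compile`, so the same call site works across Triton versions.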

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125241
Approved by: https://github.com/jansel
commit 8a1af95b09
parent 9aa7699185
Author: Alex Baden
Date: 2024-05-04 00:10:53 +00:00
Committed by: PyTorch MergeBot


@@ -55,11 +55,17 @@ if triton is not None:
         from triton.compiler.compiler import ASTSource
     except ImportError:
         ASTSource = None
+
+    try:
+        from triton.backends.compiler import GPUTarget
+    except ImportError:
+        GPUTarget = None
 else:
     Config = object
     KernelInterface = object
     OutOfResources = object
     ASTSource = None
+    GPUTarget = None
 
 try:
     autograd_profiler = torch.autograd.profiler
@@ -334,11 +340,22 @@ class CachingAutotuner(KernelInterface):
         else:
             rocm_warp_size = 64
-        target = (
-            (compile_meta["device_type"], compile_meta["cc"])
-            if not torch.version.hip
-            else [compile_meta["device_type"], compile_meta["cc"], rocm_warp_size]
-        )
+        if GPUTarget:
+            target = GPUTarget(
+                compile_meta["device_type"],
+                compile_meta["cc"],
+                rocm_warp_size if torch.version.hip else 32,
+            )
+        else:
+            target = (
+                (compile_meta["device_type"], compile_meta["cc"])
+                if not torch.version.hip
+                else [
+                    compile_meta["device_type"],
+                    compile_meta["cc"],
+                    rocm_warp_size,
+                ]
+            )
 
         options = {
             "num_warps": compile_meta["num_warps"],