mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Properly package target info for triton.compile (#125241)
Triton updated the interface for `triton.compile` 5162346487
The `target` argument to compile needs to be wrapped in a `GPUTarget` object. Without proper wrapping, we hit an assert in `compile`. If that assert is removed, Triton attempts to read device info from Torch while inside a torch thread, which hits an in bad fork assert. This change is required for compatibility with latest commits in Triton. The implementation is backwards compatible, so existing versions of Triton that work now continue to work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125241
Approved by: https://github.com/jansel
This commit is contained in:
parent
9aa7699185
commit
8a1af95b09
|
|
@ -55,11 +55,17 @@ if triton is not None:
|
|||
from triton.compiler.compiler import ASTSource
|
||||
except ImportError:
|
||||
ASTSource = None
|
||||
|
||||
try:
|
||||
from triton.backends.compiler import GPUTarget
|
||||
except ImportError:
|
||||
GPUTarget = None
|
||||
else:
|
||||
Config = object
|
||||
KernelInterface = object
|
||||
OutOfResources = object
|
||||
ASTSource = None
|
||||
GPUTarget = None
|
||||
|
||||
try:
|
||||
autograd_profiler = torch.autograd.profiler
|
||||
|
|
@ -334,11 +340,22 @@ class CachingAutotuner(KernelInterface):
|
|||
else:
|
||||
rocm_warp_size = 64
|
||||
|
||||
target = (
|
||||
(compile_meta["device_type"], compile_meta["cc"])
|
||||
if not torch.version.hip
|
||||
else [compile_meta["device_type"], compile_meta["cc"], rocm_warp_size]
|
||||
)
|
||||
if GPUTarget:
|
||||
target = GPUTarget(
|
||||
compile_meta["device_type"],
|
||||
compile_meta["cc"],
|
||||
rocm_warp_size if torch.version.hip else 32,
|
||||
)
|
||||
else:
|
||||
target = (
|
||||
(compile_meta["device_type"], compile_meta["cc"])
|
||||
if not torch.version.hip
|
||||
else [
|
||||
compile_meta["device_type"],
|
||||
compile_meta["cc"],
|
||||
rocm_warp_size,
|
||||
]
|
||||
)
|
||||
|
||||
options = {
|
||||
"num_warps": compile_meta["num_warps"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user