[Triton 3.3] [ROCm] Enabled split_scan support for ROCm builds (#147619)

Fixes issue https://github.com/pytorch/pytorch/issues/133228

Enabled split_scan support for ROCm builds.

Must be handled in a non BC breaking way so this functionality is enabled conditionalised on triton version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147619
Approved by: https://github.com/davidberard98

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: David Berard <davidberard98@gmail.com>
This commit is contained in:
iupaikov-amd 2025-03-07 23:06:17 +00:00 committed by PyTorch MergeBot
parent 0f852641c2
commit f2dfe2d99c

View File

@ -108,6 +108,16 @@ else:
CUDATemplate: TypeAlias = object
try:
import triton
triton_version = triton.__version__
has_triton = True
except ImportError:
triton_version = None
has_triton = False
_T = TypeVar("_T")
_U = TypeVar("_U")
_V = TypeVar("_V")
@ -2185,7 +2195,9 @@ class Scan(Loops):
)
scan_type = Scan
if num_splits > 1:
supports_split = torch.version.hip is None and len(dtypes) == 1
supports_split = (
torch.version.hip is None or (has_triton and triton_version >= "3.3.0")
) and (len(dtypes) == 1)
if not supports_split:
if can_fallback_to_aten:
# Fallback to ATen