mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Triton 3.3] [ROCm] Enabled split_scan support for ROCm builds (#147619)
Fixes issue https://github.com/pytorch/pytorch/issues/133228 Enabled split_scan support for ROCm builds. Must be handled in a non BC breaking way so this functionality is enabled conditionalised on triton version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147619 Approved by: https://github.com/davidberard98 Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com> Co-authored-by: David Berard <davidberard98@gmail.com>
This commit is contained in:
parent
0f852641c2
commit
f2dfe2d99c
|
|
@ -108,6 +108,16 @@ else:
|
|||
CUDATemplate: TypeAlias = object
|
||||
|
||||
|
||||
try:
|
||||
import triton
|
||||
|
||||
triton_version = triton.__version__
|
||||
has_triton = True
|
||||
except ImportError:
|
||||
triton_version = None
|
||||
has_triton = False
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_U = TypeVar("_U")
|
||||
_V = TypeVar("_V")
|
||||
|
|
@ -2185,7 +2195,9 @@ class Scan(Loops):
|
|||
)
|
||||
scan_type = Scan
|
||||
if num_splits > 1:
|
||||
supports_split = torch.version.hip is None and len(dtypes) == 1
|
||||
supports_split = (
|
||||
torch.version.hip is None or (has_triton and triton_version >= "3.3.0")
|
||||
) and (len(dtypes) == 1)
|
||||
if not supports_split:
|
||||
if can_fallback_to_aten:
|
||||
# Fallback to ATen
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user