[ROCm] logsumexp on ROCm needs scaling back to natural base. (#156903)
Fixes #156012

This is a temporary fix that makes context parallelism work until the logsumexp behavior change lands in AOTriton. After discussion, we are not going to release AOTriton 0.10.1 to fix this, because:

* Even if the interface is unchanged, changing the behavior of the returned logsumexp tensor should still be considered an ABI break. Such changes do not fall into the "ABI compatible" category and should be postponed to the next release.
* AOTriton 0.11 is scheduled to be released before the end of July, which is less than five weeks away.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156903
Approved by: https://github.com/jeffdaily, https://github.com/XilunWu
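Background on the scaling constant: per the linked issue, the affected kernels return logsumexp in base 2 (a common choice for kernels built on the hardware exp2/log2 instructions), while SDPA callers expect the natural base. Since log2(S) · ln(2) = ln(S), multiplying by ln(2) ≈ 0.6931471805599453 restores the natural base. A minimal sketch of that identity (the variable names here are illustrative, not from the patch):

```python
import math

import torch

x = torch.randn(8)

# Natural-base logsumexp, the form SDPA callers expect.
lse_natural = torch.logsumexp(x, dim=0)

# Base-2 logsumexp, the form a kernel built on exp2/log2 would return.
lse_base2 = torch.log2(torch.exp(x).sum())

# Scaling by ln(2) recovers the natural base: log2(S) * ln(2) == ln(S).
torch.testing.assert_close(lse_base2 * math.log(2), lse_natural)
```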
This commit is contained in:
parent edb92e16ba
commit 1ea9cde598
```diff
@@ -43,6 +43,16 @@ class _RotateMethod(Enum):
 aten = torch.ops.aten
 logger = logging.getLogger(__name__)
 
+_is_hip: bool = hasattr(torch.version, "hip") and torch.version.hip is not None
+if _is_hip:
+    gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
+    _is_ck_supported = False
+    for arch in ["gfx942", "gfx950"]:
+        if arch in gcn_arch_name:
+            _is_ck_supported = True
+    _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library
+    _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"]
+
 
 class _DispatchMode(Enum):
     MONKEY_PATCH = auto()
```
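The detection above runs once at import time, so the per-step path in the ring loop only pays for a cheap boolean check. For reference, a hedged sketch of how one might read or set the preferred ROCm flash-attention backend; the "ck" string and the no-argument read are assumptions based on how torch.backends.cuda.preferred_blas_library behaves, not something this patch confirms:

```python
import torch

# Assumed usage: with no argument this reads the current preference; passing
# a backend name ("ck" here, matching the _ROCmFABackends key above) sets it.
if getattr(torch.version, "hip", None) is not None:
    torch.backends.cuda.preferred_rocm_fa_library("ck")
    print(torch.backends.cuda.preferred_rocm_fa_library())
```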
```diff
@@ -446,6 +456,14 @@ def _templated_ring_attention(
             is_causal=is_causal_behavior.value,
             **kwargs,
         )
+        if _is_hip:  # See: https://github.com/pytorch/pytorch/issues/156012
+            need_scaling = True
+            # Note: it is possible that CK is selected but not compiled in the binary.
+            if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
+                # Unsure about CK's behavior, keep logsumexp untouched
+                need_scaling = False
+            if need_scaling:
+                logsumexp *= 0.6931471805599453
         sdpa_merger.step(out, logsumexp, partial)
 
     return *sdpa_merger.results(), *rest
```
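Why the base matters downstream: sdpa_merger.step() folds each ring step's partial output into the running result using softmax weights derived from the logsumexp values, so a base-2 LSE silently skews those weights. A minimal sketch of the underlying merge identity (this is not PyTorch's actual _SDPAMerger, just the math it relies on):

```python
import torch

def merge_partials(out1, lse1, out2, lse2):
    # Combine two attention partials whose normalizers are exp(lse1), exp(lse2).
    lse = torch.logaddexp(lse1, lse2)          # natural-base LSE over both chunks
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # softmax mass ratio of chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # softmax mass ratio of chunk 2
    return w1 * out1 + w2 * out2, lse

# If lse1/lse2 arrived in base 2 instead, exp(lse1 - lse) would no longer equal
# the chunk's softmax mass ratio, which is the miscomputation #156012 describes.
```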