[ROCm] logsumexp on ROCm needs scaling back to natural base. (#156903)

Fixes #156012

This is a temporary fix that makes context parallelism work until the logsumexp behavior change lands in AOTriton.

After discussion we are not going to release AOTriton 0.10.1 to fix this, because:
* Even if the interface is unchanged, changing the behavior of the returned logsumexp tensor should still be considered an ABI break. Such changes do not fall into the "ABI compatible" category and should be postponed to the next release.
* AOTriton 0.11 is scheduled for release before the end of July, which is less than five weeks away.
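The scaling factor in the patch is ln(2). The underlying identity (a minimal pure-Python sketch, assuming the ROCm kernel returns logsumexp in base 2 as described in issue #156012) is that a base-2 logsumexp multiplied by ln(2) recovers the natural-base value:

```python
import math

scores = [0.3, -1.2, 2.5]

# Natural-base logsumexp, which the rest of PyTorch expects.
lse_natural = math.log(sum(math.exp(s) for s in scores))

# Base-2 logsumexp, as the affected ROCm kernels report it.
lse_base2 = math.log2(sum(math.exp(s) for s in scores))

# Scaling back to the natural base, exactly as the patch does.
recovered = lse_base2 * 0.6931471805599453  # ln(2)
```

Since log2(x) = ln(x) / ln(2), multiplying by ln(2) is an exact conversion, not an approximation.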

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156903
Approved by: https://github.com/jeffdaily, https://github.com/XilunWu
This commit is contained in:
Xinya Zhang 2025-07-14 02:50:36 +00:00 committed by PyTorch MergeBot
parent edb92e16ba
commit 1ea9cde598


@@ -43,6 +43,16 @@ class _RotateMethod(Enum):
aten = torch.ops.aten
logger = logging.getLogger(__name__)
_is_hip: bool = hasattr(torch.version, "hip") and torch.version.hip is not None
if _is_hip:
    gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    _is_ck_supported = False
    for arch in ["gfx942", "gfx950"]:
        if arch in gcn_arch_name:
            _is_ck_supported = True
    _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library
    _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"]
class _DispatchMode(Enum):
MONKEY_PATCH = auto()
@@ -446,6 +456,14 @@ def _templated_ring_attention(
            is_causal=is_causal_behavior.value,
            **kwargs,
        )
        if _is_hip:  # See: https://github.com/pytorch/pytorch/issues/156012
            need_scaling = True
            # Note: it is possible that CK is selected but not compiled into the binary.
            if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
                # Unsure about CK's behavior, keep logsumexp untouched.
                need_scaling = False
            if need_scaling:
                logsumexp *= 0.6931471805599453
        sdpa_merger.step(out, logsumexp, partial)
    return *sdpa_merger.results(), *rest
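The reason the logsumexp base matters here is that ring attention combines partial attention outputs using softmax-style weights derived from each shard's logsumexp. A hypothetical sketch of that merge step (the names `merge`, `out1`, `lse1` are illustrative, not the actual `_SDPAMerger` API) shows how a wrong-base logsumexp would skew the weights:

```python
import math

def merge(out1, lse1, out2, lse2):
    """Combine two partial attention results using their natural-base
    logsumexp values. The weights exp(lse_i - lse) only come out right
    when both lse values are in the same (natural) base."""
    lse = math.log(math.exp(lse1) + math.exp(lse2))  # combined logsumexp
    w1 = math.exp(lse1 - lse)
    w2 = math.exp(lse2 - lse)
    return w1 * out1 + w2 * out2, lse

merged_out, merged_lse = merge(1.0, 0.0, 3.0, 0.0)
```

With equal logsumexp values the two partials are weighted equally, which is why an unconverted base-2 value from one backend would silently bias the merged output.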