[CP] Fix the offsets to KV in backward (#152625)

This is more semantically correct, even though we currently assume the key and value have the same lengths.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152625
Approved by: https://github.com/XilunWu
Chien-Chin Huang 2025-05-01 11:41:07 -07:00 committed by PyTorch MergeBot
parent 1fef3cdabc
commit 36e5ff6bc4


@@ -809,7 +809,7 @@ def _templated_ring_attention_backward(
     grad_query = grad_query.to(query.dtype)
     next_grad_kv = dkv_rotater.next_buffer().to(key.dtype)
     grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
-    grad_value = next_grad_kv[grad_value.numel() :].reshape(grad_value.shape)
+    grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape)
     return (
         grad_query,
         grad_key,
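
For context, a minimal sketch (hypothetical shapes, not from the PR) of why the value slice must start at grad_key.numel(): next_grad_kv is a flat buffer holding the key and value gradients back to back, so the value portion begins where the key portion ends. With equal key/value lengths the two offsets happen to coincide, which is why the old code still worked.

import torch

# Hypothetical flat buffer packing grad_key and grad_value back to back,
# mirroring how the ring-attention backward exchanges the two gradients.
grad_key = torch.randn(2, 8, 16)    # assumed key-gradient shape
grad_value = torch.randn(2, 4, 16)  # assumed value-gradient shape (shorter than key)
next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()])

# The key occupies the first grad_key.numel() elements ...
out_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
# ... so the value slice must start at grad_key.numel(), not grad_value.numel().
out_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape)

assert torch.equal(out_key, grad_key)
assert torch.equal(out_value, grad_value)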