mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[CP] Fix the offsets to KV in backward (#152625)
This is more semantically correct even though we currently assumed KV have the same lengths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152625 Approved by: https://github.com/XilunWu
This commit is contained in:
parent
1fef3cdabc
commit
36e5ff6bc4
|
|
@ -809,7 +809,7 @@ def _templated_ring_attention_backward(
|
|||
grad_query = grad_query.to(query.dtype)
|
||||
next_grad_kv = dkv_rotater.next_buffer().to(key.dtype)
|
||||
grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
|
||||
grad_value = next_grad_kv[grad_value.numel() :].reshape(grad_value.shape)
|
||||
grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape)
|
||||
return (
|
||||
grad_query,
|
||||
grad_key,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user