[CP] Fix the offsets to KV in backward (#152625)
This is more semantically correct, even though we currently assume K and V have the same length.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152625
Approved by: https://github.com/XilunWu
Commit 36e5ff6bc4 (parent 1fef3cdabc)
@@ -809,7 +809,7 @@ def _templated_ring_attention_backward(
     grad_query = grad_query.to(query.dtype)
     next_grad_kv = dkv_rotater.next_buffer().to(key.dtype)
     grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
-    grad_value = next_grad_kv[grad_value.numel() :].reshape(grad_value.shape)
+    grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape)
     return (
         grad_query,
         grad_key,
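For illustration, a minimal standalone sketch (not the actual _templated_ring_attention_backward code; the shapes and the buffer construction are hypothetical stand-ins for what dkv_rotater.next_buffer() returns) of why the dV offset should be grad_key.numel(): the rotated buffer concatenates the flattened dK followed by the flattened dV, so dV begins exactly where dK ends. The two offsets coincide only while K and V have equal sizes, which is why the old indexing still produced correct results.

import torch

# Hypothetical shapes: ring attention currently assumes K and V have the
# same length, so grad_key.numel() == grad_value.numel() and both offsets
# happen to agree today.
grad_key = torch.randn(2, 8, 16, 64)    # dK
grad_value = torch.randn(2, 8, 16, 64)  # dV

# The communication buffer holds the flattened dK followed by the
# flattened dV (standing in for dkv_rotater.next_buffer()).
next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()])

# dK occupies the first grad_key.numel() elements ...
out_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
# ... so dV starts at offset grad_key.numel(), not grad_value.numel();
# this slice stays correct even if the K and V sizes ever diverge.
out_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape)

assert torch.equal(out_key, grad_key)
assert torch.equal(out_value, grad_value)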