diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py
index 2b9c9727c02..03e73631151 100644
--- a/torch/distributed/tensor/experimental/_attention.py
+++ b/torch/distributed/tensor/experimental/_attention.py
@@ -809,7 +809,7 @@ def _templated_ring_attention_backward(
     grad_query = grad_query.to(query.dtype)
     next_grad_kv = dkv_rotater.next_buffer().to(key.dtype)
     grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
-    grad_value = next_grad_kv[grad_value.numel() :].reshape(grad_value.shape)
+    grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape)
     return (
         grad_query,
         grad_key,