Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Revert "bwd pass (#164504)"
This reverts commit f36f372acc.
Reverted https://github.com/pytorch/pytorch/pull/164504 on behalf of https://github.com/jeffdaily due to CI had been clean for both cuda and rocm before merge, broke post merge? ([comment](https://github.com/pytorch/pytorch/pull/164504#issuecomment-3462116676))
parent 467c21ad9a
commit d6d6fa26f5
@@ -5,12 +5,11 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.varlen import varlen_attn
from torch.nn.attention import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.utils._python_dispatch import TorchDispatchMode


VarlenShape = namedtuple(
@@ -24,18 +23,6 @@ default_tolerances = {
}


class OpLoggingMode(TorchDispatchMode):
    """Logging mode that captures all dispatched operations"""

    def __init__(self):
        self.called_ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        op_name = str(func)
        self.called_ops.append(op_name)
        return func(*args, **(kwargs or {}))


class AttentionBlock(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
@@ -52,9 +39,12 @@ class AttentionBlock(nn.Module):
            embed_dim, embed_dim, bias=False, device=device, dtype=dtype
        )

    def get_varlen_qkv(
    def forward_varlen(
        self,
        x_packed: torch.Tensor,
        cu_seq: torch.Tensor,
        max_len: int,
        is_causal: bool = False,
    ):
        qkv = self.qkv_proj(x_packed)
        q, k, v = qkv.chunk(3, dim=-1)
@@ -63,51 +53,24 @@ class AttentionBlock(nn.Module):
        k = k.view(-1, self.num_heads, self.head_dim)
        v = v.view(-1, self.num_heads, self.head_dim)

        return q, k, v

    def forward_varlen(
        self,
        x_packed: torch.Tensor,
        cu_seq: torch.Tensor,
        max_len: int,
        is_causal: bool = False,
    ):
        q, k, v = self.get_varlen_qkv(x_packed)

        attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
        attn_out = varlen_attn(
            q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
        )
        attn_out = attn_out.view(-1, self.embed_dim)

        return self.out_proj(attn_out)

    def forward_sdpa(
        self,
        x_padded: torch.Tensor,
        seq_lengths: torch.Tensor,
        dtype: torch.dtype,
        is_causal: bool = False,
    ):
    def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
        batch_size, seq_len, _ = x_padded.shape

        qkv = self.qkv_proj(x_padded)
        q, k, v = qkv.chunk(3, dim=-1)

        mask = (
            torch.arange(seq_len, device=x_padded.device)[None, :]
            < seq_lengths[:, None]
        )

        attn_mask = mask[:, None, None, :].expand(
            batch_size, self.num_heads, seq_len, seq_len
        )

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=is_causal
        )

        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        attn_out = (
            attn_out.transpose(1, 2)
            .contiguous()
@@ -128,9 +91,7 @@ def create_variable_length_batch(
    seq_lengths = torch.tensor(seq_lengths, device=device)
    total_tokens = seq_lengths.sum().item()

    x_packed = torch.randn(
        total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
    )
    x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)

    cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
    cu_seq[1:] = seq_lengths.cumsum(0)
@@ -145,7 +106,6 @@ def create_variable_length_batch(
        end_idx = start_idx + seq_len
        x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
        start_idx = end_idx
    x_padded = x_padded.clone().detach().requires_grad_()

    return {
        "seq_lengths": seq_lengths,
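The two hunks above build the packed ("varlen") batch that the tests operate on: all tokens of the batch concatenated along the first dimension, with cumulative offsets marking where each sequence starts. As a self-contained illustration of that layout (the shapes and names below are illustrative, not taken verbatim from the test helper), here is a minimal sketch of the packed/padded round trip:

import torch

# Illustrative sketch of the packed "varlen" layout used by the helpers above.
seq_lengths = torch.tensor([3, 5, 2])          # tokens per sequence
total_tokens = int(seq_lengths.sum())          # the packed batch has no padding
embed_dim = 8

# Packed layout: every token concatenated -> (total_tokens, embed_dim)
x_packed = torch.randn(total_tokens, embed_dim)

# cu_seq[i] is the offset where sequence i starts; cu_seq[-1] == total_tokens
cu_seq = torch.zeros(seq_lengths.numel() + 1, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)

# Padded layout: (batch, max_len, embed_dim), zero-filled past each true length
max_len = int(seq_lengths.max())
x_padded = torch.zeros(seq_lengths.numel(), max_len, embed_dim)
for i, seq_len in enumerate(seq_lengths.tolist()):
    start = int(cu_seq[i])
    x_padded[i, :seq_len] = x_packed[start : start + seq_len]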
@@ -173,11 +133,7 @@ class TestVarlenAttention(NNTestCase):

        total_tokens = shape.batch_size * shape.max_seq_len
        x_packed = torch.randn(
            total_tokens,
            shape.embed_dim,
            device=device,
            dtype=dtype,
            requires_grad=True,
            total_tokens, shape.embed_dim, device=device, dtype=dtype
        )
        cu_seq = torch.tensor(
            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@@ -191,128 +147,6 @@ class TestVarlenAttention(NNTestCase):
        self.assertEqual(output.device, torch.device(device))
        self.assertEqual(output.dtype, dtype)

        varlen_grad_out = torch.ones_like(output)

        varlen_grad = torch.autograd.grad(
            outputs=output,
            inputs=x_packed,
            grad_outputs=varlen_grad_out,
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )[0]

        self.assertIsNotNone(varlen_grad)
        self.assertEqual(varlen_grad.shape, x_packed.shape)
        self.assertEqual(varlen_grad.dtype, x_packed.dtype)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
    )
    @parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_custom_op_compliance(self, device, dtype):
        torch.manual_seed(42)

        shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)

        attention_block = AttentionBlock(
            shape.embed_dim, shape.num_heads, device, dtype
        )

        total_tokens = shape.batch_size * shape.max_seq_len
        x_packed = torch.randn(
            total_tokens,
            shape.embed_dim,
            device=device,
            dtype=dtype,
        )
        cu_seq = torch.tensor(
            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
        )

        q, k, v = attention_block.get_varlen_qkv(x_packed)

        torch.library.opcheck(
            torch.ops.torch_attn._varlen_attn,
            (q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
        )

        out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
            q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
        )
        grad_out = torch.randn_like(out)

        # we don't support double backward
        # skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
        torch.library.opcheck(
            torch.ops.torch_attn._varlen_attn_backward,
            (
                grad_out,
                q,
                k,
                v,
                out,
                lse,
                cu_seq,
                cu_seq,
                shape.max_seq_len,
                shape.max_seq_len,
                False,
                rng_state,
            ),
            test_utils=["test_schema", "test_faketensor"],
        )

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
    )
    @parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_custom_op_registration(self, device, dtype):
        torch.manual_seed(42)

        shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)

        attention_block = AttentionBlock(
            shape.embed_dim, shape.num_heads, device, dtype
        )

        total_tokens = shape.batch_size * shape.max_seq_len
        x_packed = torch.randn(
            total_tokens,
            shape.embed_dim,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        cu_seq = torch.tensor(
            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
        )

        compiled_forward = torch.compile(
            attention_block.forward_varlen, backend="eager", fullgraph=True
        )
        with OpLoggingMode() as mode:
            output = compiled_forward(
                x_packed, cu_seq, shape.max_seq_len, is_causal=False
            )

        varlen_grad_out = torch.ones_like(output)
        _ = torch.autograd.grad(
            outputs=output,
            inputs=x_packed,
            grad_outputs=varlen_grad_out,
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )[0]

        called_ops = mode.called_ops

        custom_ops_called = any(
            "torch_attn._varlen_attn" in op for op in called_ops
        ) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
        assert custom_ops_called

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
    )
@@ -338,10 +172,7 @@ class TestVarlenAttention(NNTestCase):
            is_causal=is_causal,
        )
        sdpa_output = attention_block.forward_sdpa(
            variable_length_batch_data["x_padded"],
            variable_length_batch_data["seq_lengths"],
            dtype=dtype,
            is_causal=is_causal,
            variable_length_batch_data["x_padded"], is_causal=is_causal
        )

        tolerances = default_tolerances[dtype]
@@ -355,44 +186,6 @@ class TestVarlenAttention(NNTestCase):
            torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
            start_idx = end_idx

        varlen_grad_out = torch.ones_like(varlen_output)

        sdpa_grad_out = torch.zeros_like(sdpa_output)

        start_idx = 0
        for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
            end_idx = start_idx + seq_len
            sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
            start_idx = end_idx

        varlen_grad = torch.autograd.grad(
            outputs=varlen_output,
            inputs=variable_length_batch_data["x_packed"],
            grad_outputs=varlen_grad_out,
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )[0]

        sdpa_grad = torch.autograd.grad(
            outputs=sdpa_output,
            inputs=variable_length_batch_data["x_padded"],
            grad_outputs=sdpa_grad_out,
            retain_graph=True,
            create_graph=False,
            allow_unused=False,
        )[0]

        start_idx = 0
        for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
            end_idx = start_idx + seq_len

            varlen_grad_seq = varlen_grad[start_idx:end_idx]
            sdpa_grad_seq = sdpa_grad[i, :seq_len]

            torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
            start_idx = end_idx


device_types = ("cuda",)
@@ -14,11 +14,14 @@ from torch.backends.cuda import (
    SDPAParams,
)

from .varlen import varlen_attn


__all__: list[str] = [
    "SDPBackend",
    "sdpa_kernel",
    "WARN_FOR_UNFUSED_KERNELS",
    "varlen_attn",
]

# Note: [SDPA warnings]
@@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.

import logging
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Union
from typing import NamedTuple, Optional, Union

import torch
@@ -33,7 +33,8 @@ class AuxRequest(NamedTuple):
    lse: bool = False


@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
def _varlen_attn(
    query: torch.Tensor,
    key: torch.Tensor,
@@ -43,7 +44,7 @@ def _varlen_attn(
    max_q: int,
    max_k: int,
    is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Private custom op for variable-length attention.
@@ -69,7 +70,7 @@ def _varlen_attn(
            False, # return_debug_mask
        )
        # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
        output, softmax_lse, rng_state = result[0], result[1], result[6]
        output, softmax_lse = result[0], result[1]
    else:
        log.info("Using Flash Attention backend for varlen_attn")
        output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
@@ -85,13 +86,10 @@ def _varlen_attn(
            return_debug_mask=False,
        )

        rng_state_ = torch.zeros(
            (2,), dtype=torch.uint64, device=query.device
        ) # hardcoded since dropout is hardcoded to 0
        return output, softmax_lse, rng_state_
    return output, softmax_lse


@_varlen_attn.register_fake
# @_varlen_attn.register_fake
def _varlen_attn_fake(
    query: torch.Tensor,
    key: torch.Tensor,
@@ -101,7 +99,7 @@ def _varlen_attn_fake(
    max_q: int,
    max_k: int,
    is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fake implementation for meta tensor computation and tracing.
@@ -119,9 +117,7 @@ def _varlen_attn_fake(
        (num_heads, total_q), dtype=torch.float, device=query.device
    )

    rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)

    return output, logsumexp, rng_state
    return output, logsumexp


def varlen_attn(
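The hunks above declare torch_attn::_varlen_attn through torch.library.custom_op and attach a fake (meta) kernel with register_fake so the op can be traced and compiled; the hunks that follow add the backward op and wire it in with register_autograd. For readers unfamiliar with that torch.library workflow, a minimal, self-contained sketch of the same pattern on a toy op; the name mylib::scaled_add and every helper below are illustrative and not part of this PR:

import torch

@torch.library.custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Real kernel: runs on actual tensors.
    return x + alpha * y

@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Fake kernel: metadata only, used for meta tensors and tracing.
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    # Stash whatever backward needs; mirrors _setup_context in the diff below.
    _, _, alpha = inputs
    ctx.alpha = alpha

def _backward(ctx, grad_out):
    # One gradient per op input; None for the non-tensor alpha.
    return grad_out, ctx.alpha * grad_out, None

scaled_add.register_autograd(_backward, setup_context=_setup_context)

# torch.library.opcheck runs schema, fake-tensor, and autograd-registration
# checks, which is what test_custom_op_compliance does for _varlen_attn in the
# test hunks earlier in this diff.
x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
torch.library.opcheck(torch.ops.mylib.scaled_add, (x, y, 2.0))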
@@ -195,145 +191,9 @@ def varlen_attn(
        ... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
        ... )
    """
    out, lse, _ = torch.ops.torch_attn._varlen_attn(
    out, lse = _varlen_attn(
        query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
    )
    if return_aux is not None and return_aux.lse:
        return out, lse
    return out


def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
    query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
    out, lse, rng_state = output
    ctx.query = query
    ctx.key = key
    ctx.value = value
    ctx.cu_seq_q = cu_seq_q
    ctx.cu_seq_k = cu_seq_k
    ctx.max_q = max_q
    ctx.max_k = max_k
    ctx.is_causal = is_causal
    ctx.output = out
    ctx.lse = lse
    ctx.rng_state = rng_state


@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool,
    rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    unused = torch.empty(0, device=query.device)

    use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
    if use_cudnn:
        log.info("Using cuDNN backend for varlen_attn")
        dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
            grad_out,
            query,
            key,
            value,
            out,
            lse,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,
            is_causal,
            rng_state,
            unused,
        )
    else:
        log.info("Using Flash Attention backend for varlen_attn")
        dq, dk, dv = torch.ops.aten._flash_attention_backward(
            grad_out,
            query,
            key,
            value,
            out,
            lse,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,
            is_causal,
            rng_state,
            unused,
        )
    return dq, dk, dv


@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool,
    rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Fake implementation for meta tensor computation and tracing.
    """

    grad_query = torch.empty_like(query)
    grad_key = torch.empty_like(key)
    grad_value = torch.empty_like(value)

    return grad_query, grad_key, grad_value


def _backward(
    ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
    query = ctx.query
    key = ctx.key
    value = ctx.value
    cu_seq_q = ctx.cu_seq_q
    cu_seq_k = ctx.cu_seq_k
    max_q = ctx.max_q
    max_k = ctx.max_k
    is_causal = ctx.is_causal
    out = ctx.output
    lse = ctx.lse
    rng_state = ctx.rng_state

    # rng_state = torch.empty(2, device=query.device)

    dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        cu_seq_q,
        cu_seq_k,
        max_q,
        max_k,
        is_causal,
        rng_state,
    )
    return dq, dk, dv, None, None, None, None, None, None


_varlen_attn.register_autograd(_backward, setup_context=_setup_context)
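For completeness, a hedged sketch of calling the public entry point, following the docstring shown in the hunk above. The shapes, the CUDA/bfloat16 assumptions, and the torch.nn.attention.varlen import path (one of the two paths that appear in this diff) are illustrative; a GPU with a supported Flash or cuDNN backend is assumed:

import torch
from torch.nn.attention.varlen import varlen_attn  # import path as it appears on one side of this diff

device, dtype = "cuda", torch.bfloat16
num_heads, head_dim = 8, 64
seq_lengths = torch.tensor([3, 5, 2], device=device)
total_tokens = int(seq_lengths.sum())
max_len = int(seq_lengths.max())

# Packed projections: (total_tokens, num_heads, head_dim), no padding between sequences.
q, k, v = (
    torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
    for _ in range(3)
)

# Cumulative sequence offsets; q and k/v share them here (self-attention).
cu_seq = torch.zeros(seq_lengths.numel() + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)

out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)
out.sum().backward()  # exercises the backward registration that #164504 added and this revert removes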