[dtensor] add CuDNN SDPA op support to DTensor (#148537)

### Summary
This PR adds DTensor op support for `_scaled_dot_product_cudnn_attention` and `_scaled_dot_product_cudnn_attention_backward`: sharding-propagation strategies for both ops, plus the corresponding ring-attention (Context Parallel) handlers.

### Test
`pytest test/distributed/tensor/test_attention.py`
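For illustration only (not part of the PR), below is a minimal sketch of how the new strategies might be exercised through DTensor sharding propagation: q/k/v are sharded on the head dim (the tensor-parallel case handled by the new strategy) and SDPA is pinned to the cuDNN backend. The 2-GPU mesh, shapes, and the `torchrun` launch are assumptions; the PR's own coverage lives in `test/distributed/tensor/test_attention.py`.

```python
# Minimal sketch (not from the PR): exercising the new cuDNN SDPA DTensor strategies.
# Assumes `torchrun --nproc-per-node=2` on GPUs where cuDNN attention is available.
import os
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (2,))

torch.manual_seed(0)  # same global q/k/v on every rank before sharding
bs, nheads, seq_len, head_dim = 4, 8, 256, 64
q, k, v = (
    torch.randn(bs, nheads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)
# Shard(1) splits the head dim across ranks, matching the "tensor parallel" strategy.
dq, dk, dv = (distribute_tensor(t, mesh, [Shard(1)]).requires_grad_() for t in (q, k, v))

with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(dq, dk, dv, is_causal=True)
out.sum().backward()  # routes through the new cuDNN backward strategy as well
```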

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148537
Approved by: https://github.com/drisspg, https://github.com/fegin
Xilun Wu 2025-03-05 16:19:10 -08:00 committed by PyTorch MergeBot
parent 3960f97832
commit e2a0296e80
3 changed files with 303 additions and 1 deletion

test/distributed/tensor/test_attention.py (View File)

@@ -21,6 +21,7 @@ from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -41,6 +42,8 @@ if PLATFORM_SUPPORTS_FLASH_ATTENTION:
backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
backends.append(SDPBackend.CUDNN_ATTENTION)
rotater_enum_to_str = {
_RotateMethod.ALL_GATHER: "allgather",
@@ -109,7 +112,10 @@ class RingAttentionTest(DTensorTestBase):
nheads = 8
torch.manual_seed(10)
dtype = (
-torch.bfloat16 if backend == SDPBackend.FLASH_ATTENTION else torch.float32
+torch.bfloat16
+if backend == SDPBackend.FLASH_ATTENTION
+or backend == SDPBackend.CUDNN_ATTENTION
+else torch.float32
)
_cp_options.enable_load_balance = load_balance

torch/distributed/tensor/_ops/_matrix_ops.py (View File)

@@ -2,6 +2,8 @@
# implement matrix related ops for distributed tensor
from typing import Optional
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -570,3 +572,201 @@ def scaled_dot_product_efficient_attention_backward_strategy(
single_mesh_dim_strategies,
input_index=4,
)
@register_op_strategy(
aten._scaled_dot_product_cudnn_attention.default,
schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args()
(
query_strategy, # query
_, # key
_, # value
attn_bias_strategy,
compute_log_sumexp, # compute_log_sumexp
*rest_args, # optional args: dropout_p, is_causal, return_debug_mask, scale
) = op_schema.args_schema
return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
has_attn_bias = attn_bias_strategy is not None
debug_attn_mask_sharding: Optional[Placement] = (
Replicate() if return_debug_mask else None
)
assert isinstance(query_strategy, OpStrategy)
# assuming q/k/v have the same shape
single_mesh_dim_strategies = []
# placement list stores placements of [outputs, inputs]
# in the sdpa case, we have 2 valid tensor outputs and 3 tensor inputs
# first we can always accept full replication for both inputs and outputs
all_replicate: PlacementList = [
Replicate(), # output
Replicate(), # logsumexp
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
# NOTE: debug_attn_mask is not supported by pytorch and is always an empty tensor
# https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840
debug_attn_mask_sharding, # debug_attn_mask
Replicate(), # q
Replicate(), # k
Replicate(), # v
]
if has_attn_bias:
all_replicate.append(Replicate()) # attn bias
single_mesh_dim_strategies.append(all_replicate)
# second we can accept the sharding pattern of tensor parallelism, which
# shards on the num_heads dim
tp_sharding = Shard(1) # num head dim
qkv_sharding = tp_sharding
output_sharding = tp_sharding
logsumexp_sharding = tp_sharding if compute_log_sumexp else Replicate()
debug_attn_mask_sharding = tp_sharding if return_debug_mask else None
num_heads_dim_sharding: PlacementList = [
output_sharding,
logsumexp_sharding,
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
debug_attn_mask_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
]
single_mesh_dim_strategies.append(num_heads_dim_sharding)
# Context Parallelism: shards on the sequence dim
cp_sharding = Shard(2) # seq dim
logsumexp_sharding = cp_sharding if compute_log_sumexp else Replicate()
debug_attn_mask_sharding = cp_sharding if return_debug_mask else None
single_mesh_dim_strategies.append(
[
cp_sharding, # output
logsumexp_sharding, # logsumexp
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
debug_attn_mask_sharding, # debug_attn_mask
cp_sharding, # q
cp_sharding, # k
cp_sharding, # v
]
)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=9
)
@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
op_schema: OpSchema,
) -> OpStrategy:
# backward op does not need to validate the mesh since forward op has already done it
mesh = op_schema.get_mesh_from_args(validate=False)
assert len(op_schema.args_schema) >= 15
has_attn_bias = op_schema.args_schema[8] is not None
has_scale = len(op_schema.args_schema) >= 16 and False
query_strategy = op_schema.args_schema[1]
assert isinstance(query_strategy, OpStrategy)
# assuming q/k/v have the same shape
single_mesh_dim_strategies = []
# placement list stores placements of [outputs, inputs]
# cudnn outputs: (Tensor dq, Tensor dk, Tensor dv)
# cudnn inputs: (
# Tensor grad_out,
# Tensor query,
# Tensor key,
# Tensor value,
# Tensor out,
# Tensor logsumexp,
# Tensor philox_seed,
# Tensor philox_offset,
# Tensor attn_bias,
# Tensor cum_seq_q,
# Tensor cum_seq_k,
# SymInt max_q,
# SymInt max_k,
# float dropout_p,
# bool is_causal,
# float? scale,
# )
# case 1: we can always accept full replication for both inputs and outputs
all_replicate_out: PlacementList = [
Replicate(), # dq
Replicate(), # dk
Replicate(), # dv
]
all_replicate_inp: PlacementList = [Replicate()] * 6
all_replicate_inp += [
Replicate()
] * 2 # philox_seed and philox_offset are cast to Replicate() in DTensor
all_replicate_inp += [Replicate() if has_attn_bias else None]
all_replicate_inp += [None] * 6
if has_scale:
all_replicate_inp.append(None)
all_replicate: PlacementList = all_replicate_out + all_replicate_inp
single_mesh_dim_strategies.append(all_replicate)
# case 2: we can accept the sharding pattern of tensor parallelism, which
# shards on the num_heads dim
qkv_sharding = Shard(1) # num head dim
output_sharding = Shard(1) # num head dim
logsumexp_sharding = Shard(1) # num head dim
num_heads_dim_sharding_out: PlacementList = [qkv_sharding] * 3
num_heads_dim_sharding_inp: PlacementList = [qkv_sharding] * 4
num_heads_dim_sharding_inp += [output_sharding]
num_heads_dim_sharding_inp += [logsumexp_sharding]
num_heads_dim_sharding_inp += [
Replicate()
] * 2 # philox_seed and philox_offset are cast to Replicate() in DTensor
num_heads_dim_sharding_inp += [Shard(1) if has_attn_bias else None]
num_heads_dim_sharding_inp += [None] * 6
if has_scale:
num_heads_dim_sharding_inp.append(None)
num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp
single_mesh_dim_strategies.append(num_heads_dim_sharding)
# case 3: Context Parallelism which shards on the sequence dim
context_parallel_sharding_out: PlacementList = [Shard(2)] * 3
context_parallel_sharding_inp: PlacementList = [Shard(2)] * 6
context_parallel_sharding_inp += [
Replicate()
] * 2 # philox_seed and philox_offset are cast to Replicate() in DTensor
context_parallel_sharding_inp += [Shard(2) if has_attn_bias else None]
context_parallel_sharding_inp += [None] * 6
if has_scale:
context_parallel_sharding_inp.append(None)
context_parallel_sharding = (
context_parallel_sharding_out + context_parallel_sharding_inp
)
single_mesh_dim_strategies.append(context_parallel_sharding)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=3
)
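For orientation (an editorial note, not part of the diff): each placement list above follows DTensor's `[outputs..., inputs...]` convention for `expand_to_full_mesh_op_strategy`, so `input_index=9` in the forward strategy marks off the nine output slots of the cuDNN op before the q/k/v inputs (plus the optional attn_bias), and `input_index=3` in the backward strategy marks off dq/dk/dv. A standalone sketch of that layout for the tensor-parallel case, with values chosen here purely for illustration:

```python
# Editorial sketch (not in the PR): the [outputs..., inputs...] placement-list layout
# behind input_index=9 in the forward cuDNN strategy.
from torch.distributed.tensor import Shard

input_index = 9          # number of output slots before the tensor inputs begin
tp = Shard(1)            # shard on the num_heads dim (tensor parallelism)
num_heads_dim_sharding = [
    tp,                  # output
    tp,                  # logsumexp (Replicate() instead when compute_log_sumexp=False)
    None, None,          # cum_seq_q, cum_seq_k (no placement needed)
    None, None,          # max_q, max_k
    None, None,          # philox_seed, philox_offset
    None,                # debug_attn_mask (only placed when return_debug_mask=True)
    tp, tp, tp,          # query, key, value inputs
]
outputs = num_heads_dim_sharding[:input_index]
inputs = num_heads_dim_sharding[input_index:]
assert len(outputs) == 9 and len(inputs) == 3
```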

torch/distributed/tensor/experimental/_attention.py (View File)

@@ -246,6 +246,43 @@ def _scaled_dot_product_ring_efficient_attention(
)
def _scaled_dot_product_ring_cudnn_attention(
mesh: DeviceMesh,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
compute_log_sumexp: bool = True,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> tuple[torch.Tensor, ...]:
if attn_bias is not None:
raise NotImplementedError("attn_bias is not supported yet")
if not compute_log_sumexp:
# CP requires compute_log_sumexp to be True because it always merges LSE
compute_log_sumexp = True
seq_dim = 2
return _templated_ring_attention(
mesh,
seq_dim,
aten._scaled_dot_product_cudnn_attention,
query=query,
key=key,
value=value,
attn_bias=attn_bias,
compute_log_sumexp=compute_log_sumexp,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=return_debug_mask,
scale=scale,
)
class _AttentionOp(Protocol):
def __call__(
self,
@@ -545,6 +582,12 @@ def _sdpa_handler(
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
elif op_call == aten._scaled_dot_product_cudnn_attention.default:
local_results = _scaled_dot_product_ring_cudnn_attention(
op_info.compute_mesh,
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
else:
raise NotImplementedError(
"CP only supports flash attention and memory efficient attention now."
@@ -584,6 +627,12 @@ def _sdpa_backward_handler(
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
elif op_call == aten._scaled_dot_product_cudnn_attention_backward.default:
local_results = _scaled_dot_product_ring_cudnn_attention_backward(
op_info.compute_mesh,
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
else:
raise NotImplementedError(f"{op_call=}")
@@ -841,11 +890,58 @@ def _scaled_dot_product_ring_efficient_attention_backward(
)
def _scaled_dot_product_ring_cudnn_attention_backward(
mesh: DeviceMesh,
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
attn_bias: torch.Tensor,
cum_seq_q: torch.Tensor,
cum_seq_k: torch.Tensor,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
*,
scale: Optional[float] = None,
) -> tuple[torch.Tensor, ...]:
seq_dim = 2
return _templated_ring_attention_backward(
mesh,
seq_dim,
aten._scaled_dot_product_cudnn_attention_backward.default,
grad_out=grad_out,
grad_out_name="grad_out",
query=query,
key=key,
value=value,
out=out,
logsumexp=logsumexp,
philox_seed=philox_seed,
philox_offset=philox_offset,
attn_bias=attn_bias,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=max_q,
max_k=max_k,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
customized_ops = {
aten._scaled_dot_product_flash_attention.default: _sdpa_handler,
aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler,
aten._scaled_dot_product_efficient_attention.default: _sdpa_handler,
aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler,
aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler,
aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_backward_handler,
}
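As an editorial sketch (not from the PR), the ring cuDNN handlers registered above might be exercised through the experimental context-parallel API roughly as follows. The `context_parallel` signature and the backend forcing are assumptions based on the experimental API; the authoritative usage is in `test/distributed/tensor/test_attention.py`.

```python
# Rough sketch (assumptions noted above): ring cuDNN attention via Context Parallel.
# Run under `torchrun --nproc-per-node=2` on GPUs with cuDNN attention support.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel

mesh = init_device_mesh("cuda", (2,))
q, k, v = (
    torch.randn(4, 8, 512, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3)
)
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    # Inside the context, the q/k/v buffers are sharded in place along dim 2 (the
    # sequence dim) and the SDPA call dispatches through _sdpa_handler to
    # _scaled_dot_product_ring_cudnn_attention.
    with context_parallel(mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```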