[dtensor] add CuDNN SDPA op support to DTensor (#148537)
### Summary

This PR adds `_scaled_dot_product_cudnn_attention` and `_scaled_dot_product_cudnn_attention_backward` to DTensor ops.

### Test

`pytest test/distributed/tensor/test_attention.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148537
Approved by: https://github.com/drisspg, https://github.com/fegin
This commit is contained in:
parent
3960f97832
commit
e2a0296e80
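
For context before the diff: a minimal usage sketch (not part of this commit) of what the new op strategies enable, namely running scaled_dot_product_attention on DTensor inputs with the cuDNN backend selected via `sdpa_kernel`. The 2-GPU mesh, tensor shapes, and `Shard(1)` placement are illustrative assumptions.

# Illustrative sketch only; assumes torch.distributed is already initialized
# (e.g. launched with torchrun --nproc-per-node=2) and cuDNN attention is available.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.nn.attention import sdpa_kernel, SDPBackend

mesh = init_device_mesh("cuda", (2,))
torch.manual_seed(0)  # same global tensor on every rank before sharding


def make(shape=(4, 8, 128, 64)):
    # shard q/k/v over the head dim (dim 1), i.e. the tensor-parallel pattern
    return distribute_tensor(
        torch.randn(*shape, dtype=torch.bfloat16, device="cuda"), mesh, [Shard(1)]
    )


q, k, v = make(), make(), make()
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    # with this PR, DTensor can propagate shardings through the cuDNN SDPA op
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)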
@@ -21,6 +21,7 @@ from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -41,6 +42,8 @@ if PLATFORM_SUPPORTS_FLASH_ATTENTION:
    backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
    backends.append(SDPBackend.CUDNN_ATTENTION)

rotater_enum_to_str = {
    _RotateMethod.ALL_GATHER: "allgather",
@@ -109,7 +112,10 @@ class RingAttentionTest(DTensorTestBase):
         nheads = 8
         torch.manual_seed(10)
         dtype = (
-            torch.bfloat16 if backend == SDPBackend.FLASH_ATTENTION else torch.float32
+            torch.bfloat16
+            if backend == SDPBackend.FLASH_ATTENTION
+            or backend == SDPBackend.CUDNN_ATTENTION
+            else torch.float32
         )

         _cp_options.enable_load_balance = load_balance
@@ -2,6 +2,8 @@
# implement matrix related ops for distributed tensor


from typing import Optional

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -570,3 +572,201 @@ def scaled_dot_product_efficient_attention_backward_strategy(
        single_mesh_dim_strategies,
        input_index=4,
    )


@register_op_strategy(
    aten._scaled_dot_product_cudnn_attention.default,
    schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
    mesh = op_schema.get_mesh_from_args()

    (
        query_strategy,  # query
        _,  # key
        _,  # value
        attn_bias_strategy,
        compute_log_sumexp,  # compute_log_sumexp
        *rest_args,  # optional args: dropout_p, is_causal, return_debug_mask, scale
    ) = op_schema.args_schema
    return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
    has_attn_bias = attn_bias_strategy is not None
    debug_attn_mask_sharding: Optional[Placement] = (
        Replicate() if return_debug_mask else None
    )

    assert isinstance(query_strategy, OpStrategy)
    # assuming q/k/v have the same shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [outputs, inputs]
    # in the sdpa case, we have 2 valid tensor outputs and 3 tensor inputs
    # first we can always accept full replication for both inputs and outputs
    all_replicate: PlacementList = [
        Replicate(),  # output
        Replicate(),  # logsumexp
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        None,  # philox_seed
        None,  # philox_offset
        # NOTE: debug_attn_mask is not supported by pytorch and is always an empty tensor
        # https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840
        debug_attn_mask_sharding,  # debug_attn_mask
        Replicate(),  # q
        Replicate(),  # k
        Replicate(),  # v
    ]
    if has_attn_bias:
        all_replicate.append(Replicate())  # attn bias

    single_mesh_dim_strategies.append(all_replicate)

    # second we can accept the sharding pattern of tensor parallelism, which
    # shards on the num of head dim
    tp_sharding = Shard(1)  # num head dim
    qkv_sharding = tp_sharding
    output_sharding = tp_sharding
    logsumexp_sharding = tp_sharding if compute_log_sumexp else Replicate()
    debug_attn_mask_sharding = tp_sharding if return_debug_mask else None

    num_heads_dim_sharding: PlacementList = [
        output_sharding,
        logsumexp_sharding,
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        None,  # philox_seed
        None,  # philox_offset
        debug_attn_mask_sharding,
        qkv_sharding,
        qkv_sharding,
        qkv_sharding,
    ]
    single_mesh_dim_strategies.append(num_heads_dim_sharding)

    # Context Parallelism: shards on the sequence dim
    cp_sharding = Shard(2)  # seq dim
    logsumexp_sharding = cp_sharding if compute_log_sumexp else Replicate()
    debug_attn_mask_sharding = cp_sharding if return_debug_mask else None

    single_mesh_dim_strategies.append(
        [
            cp_sharding,  # output
            logsumexp_sharding,  # logsumexp
            None,  # cum_seq_q
            None,  # cum_seq_k
            None,  # max_q
            None,  # max_k
            None,  # philox_seed
            None,  # philox_offset
            debug_attn_mask_sharding,  # debug_attn_mask
            cp_sharding,  # q
            cp_sharding,  # k
            cp_sharding,  # v
        ]
    )
    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=9
    )


@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
    op_schema: OpSchema,
) -> OpStrategy:
    # backward op does not need to validate the mesh since forward op has already done it
    mesh = op_schema.get_mesh_from_args(validate=False)

    assert len(op_schema.args_schema) >= 15
    has_attn_bias = op_schema.args_schema[8] is not None
    has_scale = len(op_schema.args_schema) >= 16 and False

    query_strategy = op_schema.args_schema[1]
    assert isinstance(query_strategy, OpStrategy)
    # assuming q/k/v have the same shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [outputs, inputs]
    # cudnn outputs: (Tensor dq, Tensor dk, Tensor dv)
    # cudnn inputs: (
    #     Tensor grad_out,
    #     Tensor query,
    #     Tensor key,
    #     Tensor value,
    #     Tensor out,
    #     Tensor logsumexp,
    #     Tensor philox_seed,
    #     Tensor philox_offset,
    #     Tensor attn_bias,
    #     Tensor cum_seq_q,
    #     Tensor cum_seq_k,
    #     SymInt max_q,
    #     SymInt max_k,
    #     float dropout_p,
    #     bool is_causal,
    #     float? scale,
    # )

    # case 1: we can always accept full replication for both inputs and outputs
    all_replicate_out: PlacementList = [
        Replicate(),  # dq
        Replicate(),  # dk
        Replicate(),  # dv
    ]
    all_replicate_inp: PlacementList = [Replicate()] * 6
    all_replicate_inp += [
        Replicate()
    ] * 2  # philox_seed and philox_offset are cast to Replicate() in DTensor
    all_replicate_inp += [Replicate() if has_attn_bias else None]
    all_replicate_inp += [None] * 6
    if has_scale:
        all_replicate_inp.append(None)

    all_replicate: PlacementList = all_replicate_out + all_replicate_inp
    single_mesh_dim_strategies.append(all_replicate)

    # case 2: we can accept the sharding pattern of tensor parallelism, which
    # shards on the num of head dim
    qkv_sharding = Shard(1)  # num head dim
    output_sharding = Shard(1)  # num head dim
    logsumexp_sharding = Shard(1)  # num head dim

    num_heads_dim_sharding_out: PlacementList = [qkv_sharding] * 3
    num_heads_dim_sharding_inp: PlacementList = [qkv_sharding] * 4
    num_heads_dim_sharding_inp += [output_sharding]
    num_heads_dim_sharding_inp += [logsumexp_sharding]
    num_heads_dim_sharding_inp += [
        Replicate()
    ] * 2  # philox_seed and philox_offset are cast to Replicate() in DTensor
    num_heads_dim_sharding_inp += [Shard(1) if has_attn_bias else None]
    num_heads_dim_sharding_inp += [None] * 6
    if has_scale:
        num_heads_dim_sharding_inp.append(None)

    num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp
    single_mesh_dim_strategies.append(num_heads_dim_sharding)

    # case 3: Context Parallelism which shards on the sequence dim
    context_parallel_sharding_out: PlacementList = [Shard(2)] * 3
    context_parallel_sharding_inp: PlacementList = [Shard(2)] * 6
    context_parallel_sharding_inp += [
        Replicate()
    ] * 2  # philox_seed and philox_offset are cast to Replicate() in DTensor
    context_parallel_sharding_inp += [Shard(2) if has_attn_bias else None]
    context_parallel_sharding_inp += [None] * 6
    if has_scale:
        context_parallel_sharding_inp.append(None)

    context_parallel_sharding = (
        context_parallel_sharding_out + context_parallel_sharding_inp
    )
    single_mesh_dim_strategies.append(context_parallel_sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=3
    )
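
To make the placement-list convention above concrete, here is a small illustrative sketch (not part of the diff; the helper and slot names are hypothetical): each per-mesh-dim strategy is a flat list whose first `input_index` entries describe op outputs and whose remaining entries describe op inputs, with `None` marking values that carry no DTensor placement.

# Hypothetical illustration of the [outputs, inputs] layout; names are not from the PR.
from torch.distributed.tensor import Shard

FWD_OUTPUT_SLOTS = [  # forward uses input_index=9 -> nine leading output slots
    "output", "logsumexp", "cum_seq_q", "cum_seq_k",
    "max_q", "max_k", "philox_seed", "philox_offset", "debug_attn_mask",
]
FWD_INPUT_SLOTS = ["query", "key", "value"]  # plus "attn_bias" when present


def label_placements(placements, has_attn_bias=False):
    # Pair each placement with its slot name; None means "no placement constraint".
    names = FWD_OUTPUT_SLOTS + FWD_INPUT_SLOTS + (["attn_bias"] if has_attn_bias else [])
    assert len(names) == len(placements)
    return dict(zip(names, placements))


# e.g. the context-parallel row of the forward strategy (Shard(2) on the seq dim):
cp = Shard(2)
print(label_placements([cp, cp, None, None, None, None, None, None, None, cp, cp, cp]))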
@@ -246,6 +246,43 @@ def _scaled_dot_product_ring_efficient_attention(
    )


def _scaled_dot_product_ring_cudnn_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    compute_log_sumexp: bool = True,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: Optional[float] = None,
) -> tuple[torch.Tensor, ...]:
    if attn_bias is not None:
        raise NotImplementedError("attn_bias is not supported yet")

    if not compute_log_sumexp:
        # CP requires compute_log_sumexp to be True because it always merges LSE
        compute_log_sumexp = True

    seq_dim = 2
    return _templated_ring_attention(
        mesh,
        seq_dim,
        aten._scaled_dot_product_cudnn_attention,
        query=query,
        key=key,
        value=value,
        attn_bias=attn_bias,
        compute_log_sumexp=compute_log_sumexp,
        dropout_p=dropout_p,
        is_causal=is_causal,
        return_debug_mask=return_debug_mask,
        scale=scale,
    )


class _AttentionOp(Protocol):
    def __call__(
        self,
@@ -545,6 +582,12 @@ def _sdpa_handler(
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    elif op_call == aten._scaled_dot_product_cudnn_attention.default:
        local_results = _scaled_dot_product_ring_cudnn_attention(
            op_info.compute_mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    else:
        raise NotImplementedError(
            "CP only supports flash attention and memory efficient attention now."
@@ -584,6 +627,12 @@ def _sdpa_backward_handler(
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    elif op_call == aten._scaled_dot_product_cudnn_attention_backward.default:
        local_results = _scaled_dot_product_ring_cudnn_attention_backward(
            op_info.compute_mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    else:
        raise NotImplementedError(f"{op_call=}")
@@ -841,11 +890,58 @@ def _scaled_dot_product_ring_efficient_attention_backward(
    )


def _scaled_dot_product_ring_cudnn_attention_backward(
    mesh: DeviceMesh,
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    attn_bias: torch.Tensor,
    cum_seq_q: torch.Tensor,
    cum_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    *,
    scale: Optional[float] = None,
) -> tuple[torch.Tensor, ...]:
    seq_dim = 2
    return _templated_ring_attention_backward(
        mesh,
        seq_dim,
        aten._scaled_dot_product_cudnn_attention_backward.default,
        grad_out=grad_out,
        grad_out_name="grad_out",
        query=query,
        key=key,
        value=value,
        out=out,
        logsumexp=logsumexp,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        attn_bias=attn_bias,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )


customized_ops = {
    aten._scaled_dot_product_flash_attention.default: _sdpa_handler,
    aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler,
    aten._scaled_dot_product_efficient_attention.default: _sdpa_handler,
    aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler,
    aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler,
    aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_backward_handler,
}
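
For orientation, a rough sketch (not part of the diff) of the dispatch path these table entries enable: under the experimental `context_parallel` context manager, a cuDNN-backed SDPA call is intercepted and routed through `_sdpa_handler` / `_sdpa_backward_handler` to the ring implementations above. The mesh size, shapes, and the use of `context_parallel` here are assumptions for illustration.

# Rough sketch; assumes a 2-GPU setup launched via torchrun and the experimental
# context_parallel API from torch.distributed.tensor.experimental.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import sdpa_kernel, SDPBackend

mesh = init_device_mesh("cuda", (2,))
torch.manual_seed(0)  # same global q/k/v on every rank before sharding
q, k, v = (
    torch.randn(2, 8, 256, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
)

# Shard q/k/v along the sequence dim (dim 2); inside the context manager the
# customized_ops table above maps the cuDNN SDPA op (and its backward) to the
# ring-attention handlers.
with context_parallel(mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()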