Move attention kernels from meta_registrations to fake_impls (#120682)

This PR is mostly code movement to make the review easier - AFAIK it should not change any functionality. The final goal is to remove the xfails for some of the test_fake opinfos for these ops. Those opinfos fail because the outputs can have mixed devices; we need to move the implementations to fake_impls before we can support mixed-device returns.

This PR:
* Moves the `_meta_registrations.py` implementations to `fake_impls.py`
* Changes the function signatures from taking explicit named parameters to taking `*args, **kwargs` and normalizing them with `normalize_function`
* Wraps all the returned tensors in FakeTensors

Tests: relying on opinfos. I also checked `test_fake_*` for these ops (by removing the xfails and patching things until they passed) to verify general correctness.
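For orientation, every moved kernel follows the same pattern in `fake_impls.py`. The snippet below is a condensed, illustrative sketch of that pattern (based on the `_scaled_dot_product_flash_attention` impl in the diff further down, with most outputs elided); `register_op_impl` is the decorator already defined in `fake_impls.py`, not something imported here.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.operator_schemas import normalize_function

aten = torch.ops.aten

# Sketch only: register_op_impl is the registration decorator defined in
# fake_impls.py; the real implementation also builds attention, debug_mask,
# seed and offset the same way.
@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
    # 1) Normalize {args, kwargs} so every argument can be looked up by name.
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    query = kwargs["query"]

    # 2) Allocate outputs on the meta device, then wrap them in FakeTensors
    #    that carry the device of the real inputs.
    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    logsumexp = convert_tensor(
        torch.empty(
            (query.size(0), query.size(1), query.size(2)),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )
    # ... remaining outputs are constructed the same way; see the full diff below.
```

The philox seed and offset are still allocated on the meta device before being wrapped (see Note [Seed and Offset] in the diff), so truly mixed-device returns are left for the follow-up described above.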
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120682
Approved by: https://github.com/drisspg
Authored by David Berard on 2024-02-27 23:06:08 -08:00; committed by PyTorch MergeBot
parent 50073248ed
commit d6c202975c
3 changed files with 272 additions and 218 deletions

File: `_meta_registrations.py` (attention meta implementations removed)

@@ -5063,67 +5063,6 @@ def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    return self


@register_meta(
    [
        aten._scaled_dot_product_flash_attention,
    ]
)
def meta__scaled_dot_product_flash(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    scale: Optional[float] = None,
):
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)

    max_seqlen_batch_k = key.size(2)

    query_t = query.transpose(1, 2)
    attention = torch.empty_like(query_t).transpose(1, 2)
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_batch_q),
        dtype=torch.float,
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa
    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        torch.empty((), dtype=torch.long, device="meta"),
        torch.empty((), dtype=torch.long, device="meta"),
        debug_mask,
    )


@register_meta(
    [
        aten._scaled_dot_product_flash_attention_backward,
@@ -5238,50 +5177,6 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward(
    return grad_q, grad_k, grad_v


@register_meta(
    [
        aten._scaled_dot_product_efficient_attention,
    ]
)
def meta__scaled_dot_product_efficient(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_bias: Optional[Tensor],
    compute_log_sumexp: bool,
    dropout_p=0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)
    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = torch.empty(
        (B, num_heads, logsumexp_dim),
        dtype=torch.float,
        device=query.device,
    )

    res = res.transpose(1, 2)

    # See Note [Seed and Offset]:
    seed = torch.empty((), dtype=torch.long, device="meta")
    offset = torch.empty((), dtype=torch.long, device="meta")

    return res, logsum_exp, seed, offset


@register_meta(
    [
        aten._scaled_dot_product_efficient_attention_backward,
@@ -5342,67 +5237,6 @@ def meta__scaled_dot_product_efficient_backward(
    return grad_q, grad_k, grad_v, grad_bias


@register_meta(
    [
        aten._flash_attention_forward,
    ]
)
def meta__flash_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cum_seq_q: Optional[Tensor],
    cum_seq_k: Optional[Tensor],
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    return_debug_mask: bool,
    scale: Optional[float] = None,
):
    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # Cuda Path
    attention = torch.empty_like(query)
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_batch_q),
        dtype=torch.float,
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    # See Note [Seed and Offset]:
    return (
        attention,
        logsumexp,
        torch.empty((), dtype=torch.long, device="meta"),
        torch.empty((), dtype=torch.long, device="meta"),
        debug_mask,
    )


@register_meta(
    [
        aten._flash_attention_backward,
@@ -5432,57 +5266,6 @@ def meta__flash_attention_backward(
    return grad_query, grad_key, grad_value


@register_meta(
    [
        aten._efficient_attention_forward,
    ]
)
def meta__efficient_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Optional[Tensor],
    cu_seqlens_q: Optional[Tensor],
    cu_seqlens_k: Optional[Tensor],
    max_seqlen_q: Optional[int],
    max_seqlen_k: Optional[int],
    dropout_p: float,
    custom_mask_type: int,
    compute_log_sumexp: bool = False,
    scale: Optional[float] = None,
    causal_diagonal: Optional[Tensor] = None,
    seqlen_k: Optional[Tensor] = None,
):
    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)
    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)

    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
    actual_max_seqlen_q = M
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    logsumexp_dim = (
        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    )
    logsum_exp = torch.empty(
        (logsumexp_batch_dim, num_heads, logsumexp_dim),
        dtype=torch.float,
        device=query.device,
    )

    # See Note [Seed and Offset]:
    seed = torch.empty((), dtype=torch.long, device="meta")
    offset = torch.empty((), dtype=torch.long, device="meta")

    return res, logsum_exp, seed, offset, M, N


@register_meta(
    [
        aten._efficient_attention_backward,

File: `fake_impls.py` (attention fake implementations added)

@@ -2,6 +2,7 @@
import functools
import itertools
import math
import sys
from typing import Callable, Union
@@ -595,6 +596,259 @@ def conv(fake_mode, func, *args, **kwargs):
    )


@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)

    max_seqlen_batch_k = key.size(2)

    query_t = query.transpose(1, 2)
    # empty_like already returns a fake tensor so we don't need to convert it
    attention = torch.empty_like(query_t).transpose(1, 2)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            device=query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa
    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )

@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: attn_bias, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)
    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = convert_tensor(
        torch.empty(
            (B, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    res = res.transpose(1, 2)

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset

@register_op_impl(aten._flash_attention_forward.default)
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    cum_seq_q = kwargs["cum_seq_q"]
    cum_seq_k = kwargs["cum_seq_k"]
    max_q = kwargs["max_q"]
    max_k = kwargs["max_k"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # Cuda Path
    # note: empty_like already returns a fake tensor, we don't need to wrap it
    attention = torch.empty_like(query)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # See Note [Seed and Offset]:
    return (
        attention,
        logsumexp,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )

@register_op_impl(aten._efficient_attention_forward.default)
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    cu_seqlens_q = kwargs["cu_seqlens_q"]
    max_seqlen_q = kwargs["max_seqlen_q"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: bias, cu_seqlens_k, max_seqlen_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)
    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
    actual_max_seqlen_q = M
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    logsumexp_dim = (
        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    )
    logsum_exp = convert_tensor(
        torch.empty(
            (logsumexp_batch_dim, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset, M, N

FAST_OP_IMPLEMENTATIONS = {}

File: OpInfo test database (`op_db`) - skip/xfail updates

@@ -14428,7 +14428,14 @@ op_db: List[OpInfo] = [
        DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace',
                     device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
        DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
                     device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
                     device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
        # registered in fake_impls.py instead of _meta_registrations.py
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
                     dtypes=(torch.bfloat16, torch.float16, torch.float32)),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
                     dtypes=(torch.bfloat16, torch.float16, torch.float32)),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace_all_strides"),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),),
    ),
    OpInfo(
        'torch.ops.aten._flash_attention_forward',
@@ -14447,6 +14454,11 @@ op_db: List[OpInfo] = [
        # Device mismatch due to philox seed and offset
        DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
        DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
        # meta implementation is in fake_impls.py instead of being a meta registration
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
        # Checking the scalar value of the philox seed and offset
        DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
@@ -14481,6 +14493,11 @@ op_db: List[OpInfo] = [
        # Device mismatch due to philox seed and offset
        DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
        DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
        # meta implementation is in fake_impls.py instead of being a meta registration
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
        DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
        # Checking the scalar value of the philox seed and offset
        DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),