diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 79ad989ca8a..5d20be29448 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5063,67 +5063,6 @@ def meta_scatter_(self, dim, index, src_or_value, reduce=None):
     return self
 
 
-@register_meta(
-    [
-        aten._scaled_dot_product_flash_attention,
-    ]
-)
-def meta__scaled_dot_product_flash(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    return_debug_mask: bool = False,
-    scale: Optional[float] = None,
-):
-    batch_size = query.size(0)
-    num_heads = query.size(1)
-    max_seqlen_batch_q = query.size(2)
-    head_dim = query.size(3)
-    max_seqlen_batch_k = key.size(2)
-
-    query_t = query.transpose(1, 2)
-    attention = torch.empty_like(query_t).transpose(1, 2)
-    logsumexp = torch.empty(
-        (batch_size, num_heads, max_seqlen_batch_q),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    if return_debug_mask:
-        blocksize_c = 128 if head_dim > 64 else 256
-        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
-        if max_seqlen_batch_k <= 128:
-            max_seqlen_k = 128
-        elif max_seqlen_batch_k <= 256:
-            max_seqlen_k = 256
-        debug_mask = torch.empty(
-            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
-            dtype=query.dtype,
-            device=query.device,
-        )
-    else:
-        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
-
-    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
-    # capturing or not, but at the time of tracing we don't know if we
-    # are going to use cudagraphs or not, so we return meta tensors here
-    # it's possible we'll need to have some special handling in inductor for sdpa
-
-    return (
-        attention,
-        logsumexp,
-        None,
-        None,
-        max_seqlen_batch_q,
-        max_seqlen_batch_k,
-        torch.empty((), dtype=torch.long, device="meta"),
-        torch.empty((), dtype=torch.long, device="meta"),
-        debug_mask,
-    )
-
-
 @register_meta(
     [
         aten._scaled_dot_product_flash_attention_backward,
     ]
 )
@@ -5238,50 +5177,6 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward(
     return grad_q, grad_k, grad_v
 
-
-@register_meta(
-    [
-        aten._scaled_dot_product_efficient_attention,
-    ]
-)
-def meta__scaled_dot_product_efficient(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    attn_bias: Optional[Tensor],
-    compute_log_sumexp: bool,
-    dropout_p=0.0,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-):
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    value = value.transpose(1, 2)
-
-    B = query.size(0)
-    M = query.size(1)
-    N = key.size(1)
-    num_heads = query.size(-2)
-    K = query.size(-1)
-    Kv = value.size(-1)
-
-    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
-
-    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
-    logsum_exp = torch.empty(
-        (B, num_heads, logsumexp_dim),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    res = res.transpose(1, 2)
-
-    # See Note [Seed and Offset]:
-    seed = torch.empty((), dtype=torch.long, device="meta")
-    offset = torch.empty((), dtype=torch.long, device="meta")
-
-    return res, logsum_exp, seed, offset
-
-
 @register_meta(
     [
         aten._scaled_dot_product_efficient_attention_backward,
     ]
 )
@@ -5342,67 +5237,6 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias
 
-
-@register_meta(
-    [
-        aten._flash_attention_forward,
-    ]
-)
-def meta__flash_attention_forward(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    cum_seq_q: Optional[Tensor],
-    cum_seq_k: Optional[Tensor],
-    max_q: int,
-    max_k: int,
-    dropout_p: float,
-    is_causal: bool,
-    return_debug_mask: bool,
-    scale: Optional[float] = None,
-):
-    # NB: there are two underlying paths:
-    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
-    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
-    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
-    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
-    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
-    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
-    num_heads = query.size(-2)
-    head_dim = query.size(-1)
-
-    # Cuda Path
-    attention = torch.empty_like(query)
-    logsumexp = torch.empty(
-        (batch_size, num_heads, max_seqlen_batch_q),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    if return_debug_mask:
-        blocksize_c = 128 if head_dim > 64 else 256
-        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
-        if max_seqlen_batch_k <= 128:
-            max_seqlen_k = 128
-        elif max_seqlen_batch_k <= 256:
-            max_seqlen_k = 256
-        debug_mask = torch.empty(
-            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
-            dtype=query.dtype,
-            device=query.device,
-        )
-    else:
-        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
-
-    # See Note [Seed and Offset]:
-    return (
-        attention,
-        logsumexp,
-        torch.empty((), dtype=torch.long, device="meta"),
-        torch.empty((), dtype=torch.long, device="meta"),
-        debug_mask,
-    )
-
-
 @register_meta(
     [
         aten._flash_attention_backward,
     ]
 )
@@ -5432,57 +5266,6 @@ def meta__flash_attention_backward(
     return grad_query, grad_key, grad_value
 
-
-@register_meta(
-    [
-        aten._efficient_attention_forward,
-    ]
-)
-def meta__efficient_attention_forward(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    bias: Optional[Tensor],
-    cu_seqlens_q: Optional[Tensor],
-    cu_seqlens_k: Optional[Tensor],
-    max_seqlen_q: Optional[int],
-    max_seqlen_k: Optional[int],
-    dropout_p: float,
-    custom_mask_type: int,
-    compute_log_sumexp: bool = False,
-    scale: Optional[float] = None,
-    causal_diagonal: Optional[Tensor] = None,
-    seqlen_k: Optional[Tensor] = None,
-):
-    B = query.size(0)
-    M = query.size(1)
-    N = key.size(1)
-    num_heads = query.size(-2)
-    K = query.size(-1)
-    Kv = value.size(-1)
-
-    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
-
-    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
-    actual_max_seqlen_q = M
-    if cu_seqlens_q is not None:
-        assert max_seqlen_q is not None
-        actual_max_seqlen_q = max_seqlen_q
-    logsumexp_dim = (
-        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
-    )
-    logsum_exp = torch.empty(
-        (logsumexp_batch_dim, num_heads, logsumexp_dim),
-        dtype=torch.float,
-        device=query.device,
-    )
-
-    # See Note [Seed and Offset]:
-    seed = torch.empty((), dtype=torch.long, device="meta")
-    offset = torch.empty((), dtype=torch.long, device="meta")
-
-    return res, logsum_exp, seed, offset, M, N
-
-
 @register_meta(
     [
         aten._efficient_attention_backward,
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index 0b08d98d76b..0512e502e31 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -2,6 +2,7 @@
 
 import functools
 import itertools
+import math
 import sys
 from typing import Callable, Union
 
@@ -595,6 +596,259 @@ def conv(fake_mode, func, *args, **kwargs):
     )
 
 
+@register_op_impl(aten._scaled_dot_product_flash_attention.default)
+def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    return_debug_mask = kwargs["return_debug_mask"]
+    # unused: value, dropout_p, is_causal, scale
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    batch_size = query.size(0)
+    num_heads = query.size(1)
+    max_seqlen_batch_q = query.size(2)
+    head_dim = query.size(3)
+    max_seqlen_batch_k = key.size(2)
+
+    query_t = query.transpose(1, 2)
+    # empty_like already returns a fake tensor so we don't need to convert it
+    attention = torch.empty_like(query_t).transpose(1, 2)
+    logsumexp = convert_tensor(
+        torch.empty(
+            (batch_size, num_heads, max_seqlen_batch_q),
+            dtype=torch.float,
+            device="meta",
+        ),
+        device=query.device,
+    )
+
+    if return_debug_mask:
+        blocksize_c = 128 if head_dim > 64 else 256
+        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
+        if max_seqlen_batch_k <= 128:
+            max_seqlen_k = 128
+        elif max_seqlen_batch_k <= 256:
+            max_seqlen_k = 256
+        debug_mask = convert_tensor(
+            torch.empty(
+                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
+                dtype=query.dtype,
+                device="meta",
+            ),
+            device=query.device,
+        )
+    else:
+        debug_mask = convert_tensor(
+            torch.empty(0, dtype=query.dtype, device="meta"),
+            query.device,
+        )
+
+    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
+    # capturing or not, but at the time of tracing we don't know if we
+    # are going to use cudagraphs or not, so we return meta tensors here
+    # it's possible we'll need to have some special handling in inductor for sdpa
+
+    return (
+        attention,
+        logsumexp,
+        None,
+        None,
+        max_seqlen_batch_q,
+        max_seqlen_batch_k,
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        debug_mask,
+    )
+
+
+@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
+def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    value = kwargs["value"]
+    compute_log_sumexp = kwargs["compute_log_sumexp"]
+    # unused: attn_bias, dropout_p, is_causal, scale
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    B = query.size(0)
+    M = query.size(1)
+    N = key.size(1)
+    num_heads = query.size(-2)
+    K = query.size(-1)
+    Kv = value.size(-1)
+
+    res = convert_tensor(
+        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
+        query.device,
+    )
+
+    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+    logsum_exp = convert_tensor(
+        torch.empty(
+            (B, num_heads, logsumexp_dim),
+            dtype=torch.float,
+            device="meta",
+        ),
+        query.device,
+    )
+
+    res = res.transpose(1, 2)
+
+    # See Note [Seed and Offset]:
+    seed = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+    offset = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+
+    return res, logsum_exp, seed, offset
+
+
+@register_op_impl(aten._flash_attention_forward.default)
+def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    cum_seq_q = kwargs["cum_seq_q"]
+    cum_seq_k = kwargs["cum_seq_k"]
+    max_q = kwargs["max_q"]
+    max_k = kwargs["max_k"]
+    return_debug_mask = kwargs["return_debug_mask"]
+    # unused: value, dropout_p, is_causal, scale
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    # NB: there are two underlying paths:
+    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
+    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
+    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
+    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
+    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
+    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
+    num_heads = query.size(-2)
+    head_dim = query.size(-1)
+
+    # Cuda Path
+    # note: empty_like already returns a fake tensor, we don't need to wrap it
+    attention = torch.empty_like(query)
+    logsumexp = convert_tensor(
+        torch.empty(
+            (batch_size, num_heads, max_seqlen_batch_q),
+            dtype=torch.float,
+            device="meta",
+        ),
+        device=query.device,
+    )
+
+    if return_debug_mask:
+        blocksize_c = 128 if head_dim > 64 else 256
+        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
+        if max_seqlen_batch_k <= 128:
+            max_seqlen_k = 128
+        elif max_seqlen_batch_k <= 256:
+            max_seqlen_k = 256
+        debug_mask = convert_tensor(
+            torch.empty(
+                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
+                dtype=query.dtype,
+                device="meta",
+            ),
+            query.device,
+        )
+    else:
+        debug_mask = convert_tensor(
+            torch.empty(0, dtype=query.dtype, device="meta"),
+            query.device,
+        )
+
+    # See Note [Seed and Offset]:
+    return (
+        attention,
+        logsumexp,
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
+        debug_mask,
+    )
+
+
+@register_op_impl(aten._efficient_attention_forward.default)
+def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
+    _, kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    query = kwargs["query"]
+    key = kwargs["key"]
+    value = kwargs["value"]
+    cu_seqlens_q = kwargs["cu_seqlens_q"]
+    max_seqlen_q = kwargs["max_seqlen_q"]
+    compute_log_sumexp = kwargs["compute_log_sumexp"]
+    # unused: bias, cu_seqlens_k, max_seqlen_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k
+
+    def convert_tensor(t, device):
+        return FakeTensor(fake_mode, t, device)
+
+    B = query.size(0)
+    M = query.size(1)
+    N = key.size(1)
+    num_heads = query.size(-2)
+    K = query.size(-1)
+    Kv = value.size(-1)
+
+    res = convert_tensor(
+        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
+        query.device,
+    )
+
+    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
+    actual_max_seqlen_q = M
+    if cu_seqlens_q is not None:
+        assert max_seqlen_q is not None
+        actual_max_seqlen_q = max_seqlen_q
+    logsumexp_dim = (
+        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
+    )
+    logsum_exp = convert_tensor(
+        torch.empty(
+            (logsumexp_batch_dim, num_heads, logsumexp_dim),
+            dtype=torch.float,
+            device="meta",
+        ),
+        query.device,
+    )
+
+    # See Note [Seed and Offset]:
+    seed = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+    offset = convert_tensor(
+        torch.empty((), dtype=torch.long, device="meta"), query.device
+    )
+
+    return res, logsum_exp, seed, offset, M, N
+
+
 FAST_OP_IMPLEMENTATIONS = {}
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 73ed2811624..3bfddc3e894 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -14428,7 +14428,14 @@ op_db: List[OpInfo] = [
             DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace',
                          device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
             DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
-                         device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
+                         device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
+            # registered in fake_impls.py instead of _meta_registrations.py
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace_all_strides"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),),
     ),
     OpInfo(
         'torch.ops.aten._flash_attention_forward',
@@ -14447,6 +14454,11 @@ op_db: List[OpInfo] = [
             # Device mismatch due to philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # meta implementation is in fake_impls.py instead of being a meta registration
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
             # Checking the scalar value of the philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
@@ -14481,6 +14493,11 @@ op_db: List[OpInfo] = [
             # Device mismatch due to philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # meta implementation is in fake_impls.py instead of being a meta registration
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
            # Checking the scaler value of the philox seed and offset
             DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
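
Editor's note (not part of the patch): each of the fake impls added above reuses the same small `convert_tensor` helper, which takes a tensor allocated on the "meta" device (shape/dtype only, no storage) and wraps it in a `FakeTensor` that reports some other device, here the query's device, rather than "meta". This device re-labelling is what the replaced `@register_meta` kernels did not express, since they returned plain `torch.empty(...)` tensors. The following is a rough, self-contained sketch of that pattern; constructing `FakeTensorMode` and `FakeTensor` directly outside of dispatch, and the CPU device used here, are assumptions for illustration only.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

fake_mode = FakeTensorMode()

def convert_tensor(t, device):
    # Same helper as in the patch: wrap a meta-device tensor so the resulting
    # FakeTensor reports `device` instead of "meta".
    return FakeTensor(fake_mode, t, device)

# Shaped like the philox seed returned by the flash-attention fake impl above:
# a 0-dim int64 tensor allocated on "meta", re-labelled with a concrete device.
seed = convert_tensor(torch.empty((), dtype=torch.long, device="meta"), torch.device("cpu"))
print(seed.device, seed.dtype, seed.shape)  # cpu torch.int64 torch.Size([])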