diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4c687f9fea3..504c6cb7725 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14491,13 +14491,13 @@
     CUDA: _flash_attention_backward

 # Returns output, logsumexp if compute_logsumexp
-- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
+- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
   variants: function
   dispatch:
     CUDA: _efficient_attention_forward
   tags: nondeterministic_seeded

-- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index 607518cba61..28cbd4da3ea 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -307,7 +307,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
       : sdp::CustomMaskType::NoCustomMask;

   // See Note [Seed and Offset] for description of seed and offset
-  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
+  // Although max_seqlen_q and max_seqlen_batch_kv are returned, we drop these values.
+  auto [attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
       query_buffer_reshaped.unsqueeze(0),
       key_buffer_reshaped.unsqueeze(0),
       value_buffer_reshaped.unsqueeze(0),
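# --- Illustrative usage (not part of the patch) ---
# A minimal sketch of how the updated schema is exercised from Python, assuming a
# CUDA build with USE_MEM_EFF_ATTENTION enabled; shapes and dtypes are illustrative.
import torch

q = torch.randn(4, 11, 4, 8, device="cuda", dtype=torch.float16)  # [b, seqlen_q, num_heads, K]
k = torch.randn(4, 32, 4, 8, device="cuda", dtype=torch.float16)  # [b, seqlen_kv, num_heads, K]
v = torch.randn(4, 32, 4, 8, device="cuda", dtype=torch.float16)  # [b, seqlen_kv, num_heads, Kv]

# The op now returns six values; the trailing SymInts are the max sequence lengths.
out, lse, philox_seed, philox_offset, max_q, max_k = torch.ops.aten._efficient_attention_forward(
    q, k, v, None, None, None, None, 0.0, 0, True)
print(out.shape, max_q, max_k)  # dense path: max_q == 11, max_k == 32
# --- end of illustrative usage ---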
diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu
index 707c3b9c38c..63b507826c8 100644
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -742,7 +742,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
       ? sdp::CustomMaskType::CausalFromTopLeft
       : sdp::CustomMaskType::NoCustomMask;

-  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
+  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
       q_t,
       k_t,
       v_t,
@@ -874,7 +874,7 @@ _flash_attention_forward(
       Tensor());
 }

-std::tuple<Tensor, Tensor, Tensor, Tensor> _efficient_attention_forward(
+std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_attention_forward(
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
     const at::Tensor& value, // [b, seqlen, num_heads, Kv]
@@ -915,8 +915,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _efficient_attention_forward(
   // Embedding per head
   TORCH_CHECK(query.size(3) == key.size(3));

-  // TODO_DRISS we should return max_seqlen_k;
-  int64_t max_seqlen_q, max_seqlen_k;
+
+  int64_t max_seqlen_q = 0, max_seqlen_k = 0;
   TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value());
   if (seqstart_q.has_value()) {
     TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
@@ -1164,10 +1164,12 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _efficient_attention_forward(
       std::move(res),
       std::move(logsumexp),
       std::move(seed_t),
-      std::move(offset_t));
+      std::move(offset_t),
+      max_seqlen_q,
+      max_seqlen_k);
 #endif
   TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
-  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
+  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0);
 }

 Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){
diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu
index 5fba092c6fe..d506f3f5be0 100644
--- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -134,14 +134,14 @@ _efficient_attention_backward(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
-    const c10::optional<at::Tensor>& bias, // additive attention bias
+    const c10::optional<at::Tensor>& kernel_bias, // additive attention bias
     const at::Tensor& out,
     // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
     // position of the first query token for batch $b
-    const c10::optional<at::Tensor>& cu_seqlens_q,
+    const c10::optional<at::Tensor>& cu_seqlens_q_dummy,
     // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
     // position of the first key token for batch $b
-    const c10::optional<at::Tensor>& cu_seqlens_k,
+    const c10::optional<at::Tensor>& cu_seqlens_k_dummy,
     // (Mode 1MHK only) Maximum sequence length across batches
     int64_t max_seqlen_q,
     // (Mode 1MHK only) Maximum sequence length across batches
@@ -158,6 +158,14 @@ _efficient_attention_backward(
   if (!grad_out_.defined()) {
     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
+  // This path is used when we directly call _efficient_attention_forward
+  // from Python.
+  // This is needed because SavedVariable automatically converts
+  // c10::optional<Tensor> to an undefined tensor
+  c10::optional<Tensor> bias, cu_seqlens_q, cu_seqlens_k;
+  bias = kernel_bias.has_value() && !kernel_bias->defined() ? c10::nullopt : kernel_bias;
+  cu_seqlens_q = cu_seqlens_q_dummy.has_value() && !cu_seqlens_q_dummy->defined() ? c10::nullopt : cu_seqlens_q_dummy;
+  cu_seqlens_k = cu_seqlens_k_dummy.has_value() && !cu_seqlens_k_dummy->defined() ? c10::nullopt : cu_seqlens_k_dummy;

   // ndim
   TORCH_CHECK(query.dim() == grad_out_.dim());
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index dee237e0398..23bcc44cc32 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -264,11 +264,8 @@ ALLOW_LIST = [
     ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)),
     ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)),
     ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
-    ("aten::_scaled_dot_product_attention", datetime.date(2023, 8, 1)),
-    ("aten::_chunk_grad_outputs_efficient_attention", datetime.date(2023, 8, 1)),
-    ("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 9, 30)),
-    ("aten::_flash_attention_forward", datetime.date(2023, 9, 30)),
-    ("aten::_flash_attention_backward", datetime.date(2023, 9, 30)),
+    ("aten::_efficient_attention_forward", datetime.date(2023, 11, 30)),
+    ("aten::_efficient_attention_backward", datetime.date(2023, 11, 30)),
     ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
     ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
     ("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index d0b93c54dd3..945162ac69e 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -392,6 +392,7 @@ class TestOperators(TestCase):
         # query: last dimension must be contiguous
         # Fused attention kernels require last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail("torch.ops.aten._efficient_attention_forward"),
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
         tol1('nn.functional.binary_cross_entropy_with_logits',
@@ -474,6 +475,7 @@ class TestOperators(TestCase):
         xfail("_native_batch_norm_legit"),
         # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         xfail('nn.functional.rrelu'),  # in-place test errors out with no formula implemented
         xfail('NumpyExpMarkDirtyAutogradFunction'),  # TODO: https://github.com/pytorch/pytorch/issues/91280
@@ -601,6 +603,7 @@ class TestOperators(TestCase):
         # RuntimeError: query: last dimension must be contiguous
         # The fused attention kernels require the last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         # BUG
         # AssertionError: Tensor-likes are not close!
         xfail('as_strided'),
@@ -679,6 +682,7 @@ class TestOperators(TestCase):
         xfail('sparse.sampled_addmm', ''),  # sparse tensors have no strides
         xfail('sparse.mm', 'reduce'),  # sparse tensors have no strides
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         # AssertionError: Tensor-likes are not close!
         # Mismatched elements: 1 / 15 (6.7%)
         # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
@@ -774,6 +778,7 @@ class TestOperators(TestCase):
         skip("nn.functional.fractional_max_pool2d"),  # calls random op
         skip("nn.functional.fractional_max_pool3d"),  # calls random op
         xfail('nn.functional.scaled_dot_product_attention'),  # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         xfail('nn.functional.multi_head_attention_forward'),  # randomness
         # It looks like you're either (1) calling .item() on a Tensor or
         # (2) attempting to use a Tensor in some data-dependent control flow or
@@ -888,6 +893,7 @@ class TestOperators(TestCase):
         skip('nn.functional.dropout3d', ''),  # randomness
         skip('nn.functional.alpha_dropout'),  # randomness
         skip('nn.functional.scaled_dot_product_attention'),  # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         skip('nn.functional.multi_head_attention_forward'),  # randomness
         xfail('index_put', ''),  # not possible due to dynamic shapes; we support a subset
         xfail('nn.functional.fractional_max_pool2d'),  # random
@@ -982,6 +988,7 @@ class TestOperators(TestCase):
         skip('nn.functional.dropout2d', ''),
         skip('nn.functional.dropout3d', ''),
         skip('nn.functional.scaled_dot_product_attention'),  # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         skip('nn.functional.multi_head_attention_forward'),  # randomness
         skip('nn.functional.alpha_dropout'),  # randomness
         skip('nn.functional.feature_alpha_dropout', 'without_train'),
@@ -1253,6 +1260,7 @@ class TestOperators(TestCase):
         skip('nn.functional.feature_alpha_dropout', 'with_train'),  # randomness
         skip('nn.functional.feature_alpha_dropout', 'without_train'),  # randomness
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         skip('nn.functional.multi_head_attention_forward'),  # randomness
         skip('nn.functional.alpha_dropout'),  # randomness
         skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
@@ -1376,6 +1384,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.ctc_loss', ''),  # NYI: forward-AD for _ctc_loss
         xfail('nn.functional.pdist', ''),  # NYI: forward-AD with _pdist_forward
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         xfail('nn.functional.multi_margin_loss', ''),  # NYI: forward AD with multi_margin_loss
         skip('linalg.householder_product', '', device_type='cuda'),  # flaky, I'm not sure why
         xfail('sparse.sampled_addmm', ''),  # Sparse tensors have no strides
@@ -1500,6 +1509,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.dropout3d'),  # calls random op
         xfail('nn.functional.dropout'),  # calls random op
         xfail('nn.functional.scaled_dot_product_attention'),  # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         xfail('nn.functional.multi_head_attention_forward'),  # randomness
         xfail('nn.functional.embedding_bag'),  # Forward AD not implemented and no decomposition
         xfail('nn.functional.alpha_dropout'),  # calls randomn op
@@ -1768,6 +1778,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.max_unpool2d', 'grad'),  # contiguous call
         xfail('nn.functional.max_unpool2d'),  # contiguous call
         xfail('to_sparse'),  # dispatch key issue
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         # https://github.com/pytorch/pytorch/issues/96560
         decorate('xlogy', decorator=skipIfRocm),
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 9a3e5007bf1..b0c21421b8b 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3604,6 +3604,8 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('addcmul'),
         xfail('clamp'),

+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
+
         # TypeError: expected Tensor as element 0 in argument 0, but got float
         xfail('item'),
     }))
@@ -3660,6 +3662,7 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('nn.functional.dropout'),  # works, can't check against for loop because of randomness inconsistency
         xfail('nn.functional.scaled_dot_product_attention'),  # randomness
         xfail('nn.functional.multi_head_attention_forward'),  # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'),  # outputs ints
         xfail('resize_'),
         xfail('view_as_complex'),
         xfail('matrix_exp'),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 77829e29ab9..a305f46555b 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -234,6 +234,7 @@ inductor_expected_failures_single_sample["cuda"] = {
     "to_sparse": {f16, f32, f64},
     "pca_lowrank": {f32, f64},
     "svd_lowrank": {f32, f64},
+    "torch.ops.aten._efficient_attention_forward": {f16, bf16, f32},
 }


diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 7229ad47a51..3c19c6ed198 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2783,6 +2783,10 @@
 #   output_differentiability: [True, False, False, False, False, False, False, False]
 #   query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)

+- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
+  output_differentiability: [True, False, False, False, False, False]
+  query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
+
 # fft
 - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
   self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
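# --- Illustrative usage (not part of the patch) ---
# A minimal autograd sketch, assuming a CUDA build: with the derivatives.yaml entry
# above, gradients flow through a direct call to the forward op. compute_log_sumexp
# must be True so the backward formula has the logsumexp it needs.
import torch

q, k, v = (torch.randn(2, 8, 2, 16, device="cuda", requires_grad=True) for _ in range(3))
out = torch.ops.aten._efficient_attention_forward(
    q, k, v, None, None, None, None, 0.0, 0, True)[0]
out.sum().backward()
print(q.grad.shape, k.grad.shape, v.grad.shape)
# --- end of illustrative usage ---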
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 079eee30037..4028d8b1b48 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2082,7 +2082,8 @@ make_fallback(
     sdpa_constraint,
     warn=False,
 )
-
+make_fallback(torch.ops.aten._efficient_attention_forward.default)
+make_fallback(torch.ops.aten._efficient_attention_backward.default)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index b36abe55952..05a3377b818 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5188,6 +5188,90 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias


+@register_meta(
+    [
+        aten._efficient_attention_forward,
+    ]
+)
+def meta__efficient_attention_forward(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    bias: Optional[Tensor],
+    cu_seqlens_q: Optional[Tensor],
+    cu_seqlens_k: Optional[Tensor],
+    max_seqlen_q: Optional[int],
+    dropout_p: float,
+    custom_mask_type: int,
+    compute_log_sumexp: bool = False,
+    scale: Optional[float] = None,
+    causal_diagonal: Optional[Tensor] = None,
+    seqlen_k: Optional[Tensor] = None,
+):
+    B = query.size(0)
+    M = query.size(1)
+    N = key.size(1)
+    num_heads = query.size(-2)
+    K = query.size(-1)
+    Kv = value.size(-1)
+
+    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
+
+    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+    logsum_exp = torch.empty(
+        (B, num_heads, logsumexp_dim),
+        dtype=torch.float,
+        device=query.device,
+    )
+
+    # See Note [Seed and Offset]:
+    seed = torch.empty((), dtype=torch.long, device="meta")
+    offset = torch.empty((), dtype=torch.long, device="meta")
+
+    return res, logsum_exp, seed, offset, M, N
+
+
+@register_meta(
+    [
+        aten._efficient_attention_backward,
+    ]
+)
+def meta__efficient_attention_backward(
+    grad_out: Tensor,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    bias: Optional[Tensor],
+    cu_seqlens_q: Optional[Tensor],
+    cu_seqlens_k: Optional[Tensor],
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    logsumexp: Tensor,
+    dropout_p: float,
+    philox_seed: Tensor,
+    philox_offset: Tensor,
+    custom_mask_type: int,
+    bias_requires_grad: bool,
+    scale: Optional[float] = None,
+    num_splits_key: Optional[int] = None,
+):
+    grad_query = torch.empty_like(query)
+    grad_key = torch.empty_like(key)
+    grad_value = torch.empty_like(value)
+
+    if bias is not None:
+        assert bias is not None
+        lastDim = bias.size(-1)
+        lastDimAligned = 16 * ((lastDim + 15) // 16)
+        new_sizes = list(bias.size())
+        new_sizes[-1] = lastDimAligned
+        grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
+    else:
+        grad_bias = torch.empty((), device=query.device)
+
+    return grad_query, grad_key, grad_value, grad_bias
+
+
 @register_meta([aten._scaled_mm.default])
 def meta_scaled_mm(
     self: torch.Tensor,
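# --- Illustrative usage (not part of the patch) ---
# A minimal sketch of what the meta kernels above enable: fake tensors (and hence
# torch.compile tracing) can infer output shapes without running the CUDA kernel.
# Assumes a recent build; shapes are illustrative.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    q = torch.randn(4, 11, 4, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(4, 32, 4, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(4, 32, 4, 8, device="cuda", dtype=torch.float16)
    out, lse, seed, offset, max_q, max_k = torch.ops.aten._efficient_attention_forward(
        q, k, v, None, None, None, None, 0.0, 0, True)
    # out: [B, M, num_heads, Kv]; lse: [B, num_heads, ceil(M / 32) * 32]
    print(out.shape, lse.shape, max_q, max_k)
# --- end of illustrative usage ---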
diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py
index 3e676f59115..6aace66851c 100644
--- a/torch/_subclasses/fake_utils.py
+++ b/torch/_subclasses/fake_utils.py
@@ -45,6 +45,25 @@ def output_alias_each_other(outputs):
     return False


+def is_sdpa_error(func, idx, e):
+    if (
+        func is aten._scaled_dot_product_flash_attention.default
+        and idx in (6, 7)
+        and "Devices" in repr(e)
+    ):
+        return True
+    if (
+        (
+            func is aten._scaled_dot_product_efficient_attention.default
+            or func is aten._efficient_attention_forward.default
+        )
+        and idx in (2, 3)
+        and "Devices" in repr(e)
+    ):
+        return True
+    return False
+
+
 class CrossRefFakeMode(TorchDispatchMode):
     def __init__(
         self,
@@ -155,17 +174,7 @@ class CrossRefFakeMode(TorchDispatchMode):
                         allow_rhs_unbacked=True,
                     )
                 except Exception as e:
-                    if (
-                        func is aten._scaled_dot_product_flash_attention.default
-                        and idx in (6, 7)
-                        and "Devices" in repr(e)
-                    ):
-                        continue
-                    if (
-                        func is aten._scaled_dot_product_efficient_attention.default
-                        and idx in (2, 3)
-                        and "Devices" in repr(e)
-                    ):
+                    if is_sdpa_error(func, idx, e):
                         continue
                     error_message = (
                         f"{context} mismatched tensor metadata: {e}"
diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py
index dbd691648ca..8d7d2bff25c 100644
--- a/torch/testing/_internal/common_dtype.py
+++ b/torch/testing/_internal/common_dtype.py
@@ -76,6 +76,10 @@ _all_types_and_half = _all_types + (torch.half,)
 def all_types_and_half():
     return _all_types_and_half

+def custom_types(*dtypes):
+    """Create a list of arbitrary dtypes"""
+    return _empty_types + _validate_dtypes(*dtypes)
+
 # The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro

 # See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
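# --- Illustrative usage (not part of the patch) ---
# A small sketch of the new helper: custom_types() builds a dispatch dtype list from
# an arbitrary set of dtypes, for OpInfos whose kernels support an unusual dtype
# combination (see the dtypesIfCUDA usage in the OpInfo added below).
import torch
from torch.testing._internal.common_dtype import custom_types

dts = custom_types(torch.float16, torch.float32)
print(torch.float16 in dts, torch.bfloat16 in dts)  # expected: True False
# --- end of illustrative usage ---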
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4c655523bd7..c372171636b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -18,7 +18,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-    all_types, empty_types, complex_types_and, integral_types
+    all_types, empty_types, complex_types_and, integral_types, custom_types
 )
 from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -8242,6 +8242,78 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_

     yield from samples

+
+def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batch, num_heads, head_dim = 4, 4, 8
+    seq_q = 11
+    seq_kv = 32
+
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+    samples = []
+    mask_types = [1, 2]  # UpperLeft, LowerRight
+    scales = [None, 1.0]
+
+    for qkv_shape, is_causal, dropout_p, mask_type, scale in product(
+            qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales):
+        shape_q, shape_kv = qkv_shape
+        samples.append(SampleInput(
+            make(shape_q).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            bias=None,
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            dropout_p=dropout_p,
+            custom_mask_type=mask_type,
+            compute_log_sumexp=requires_grad,
+            scale=scale,
+            causal_diagonal=None,
+            seqlen_k=None
+        ))
+
+    # Add non standard shapes
+    diff_v_head_dim = SampleInput(
+        make((batch, seq_q, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim + 8)),
+        bias=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        max_seqlen_q=None,
+        dropout_p=dropout_p,
+        custom_mask_type=0,  # No Mask
+        compute_log_sumexp=requires_grad,
+        scale=None,
+        causal_diagonal=None,
+        seqlen_k=None
+    )
+
+    # Add an attn_mask
+    samples.append(
+        SampleInput(
+            make((batch, seq_q, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            bias=make(batch, num_heads, seq_q, seq_kv),
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            dropout_p=dropout_p,
+            custom_mask_type=0,  # No Mask
+            compute_log_sumexp=requires_grad,
+            scale=None,
+            causal_diagonal=None,
+            seqlen_k=None
+        )
+    )
+
+    yield from samples
+
+
 def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):

     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
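# --- Illustrative usage (not part of the patch) ---
# A sketch of how the OpInfo machinery consumes the generator above, assuming a CUDA
# device: SampleInput packs query as `input`, key/value as `args`, and the remaining
# schema arguments as `kwargs`, so they can be splatted straight into the op.
import torch
from torch.testing._internal.common_methods_invocations import (
    sample_inputs_efficient_attention_forward,
)

for si in sample_inputs_efficient_attention_forward(None, "cuda", torch.float16, requires_grad=False):
    out = torch.ops.aten._efficient_attention_forward(si.input, *si.args, **si.kwargs)
    print(out[0].shape)
# --- end of illustrative usage ---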
@@ -14172,6 +14244,31 @@ op_db: List[OpInfo] = [
                 DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
                              device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
     ),
+    OpInfo(
+        'torch.ops.aten._efficient_attention_forward',
+        sample_inputs_func=sample_inputs_efficient_attention_forward,
+        dtypes=empty_types(),
+        dtypesIfCUDA=custom_types(torch.float16, torch.float32)
+        if not SM80OrLater
+        else custom_types(torch.float16, torch.float32, torch.bfloat16),
+        supports_out=False,
+        supports_autograd=True,
+        supports_fwgrad_bwgrad=False,
+        supports_forward_ad=False,
+        check_batched_forward_grad=False,
+        decorators=[],
+        skips=(
+            # Device mismatch due to philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # Checking the scalar value of the philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
+            # None vs Tensor mismatch
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
+        )
+    ),
     UnaryUfuncInfo(
         'nn.functional.silu',
         aten_backward_name='silu_backward',
diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py
index 6e8e260893c..764e5d90f05 100644
--- a/torch/testing/_internal/composite_compliance.py
+++ b/torch/testing/_internal/composite_compliance.py
@@ -422,6 +422,7 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch
         results = output_process_fn_grad(results)

     flat_results = pytree.tree_leaves(results)
+    flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
     flat_diff_results = [r for r in flat_results if r.requires_grad]
     assert len(flat_diff_results) > 0

@@ -467,6 +468,7 @@ def check_backward_formula(op: Callable, args, kwargs,
         )

     flat_results = pytree.tree_leaves(results)
+    flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
     flat_diff_results = [r for r in flat_results if r.requires_grad]
     assert len(flat_diff_results) > 0