diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
index c68d3217d60..7be355b74c2 100644
--- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
@@ -40,37 +40,14 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
   return true;
 }
 
-bool input_require_grad(
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const std::optional<at::Tensor>& attn_mask) {
-  return at::GradMode::is_enabled() &&
-      (query.requires_grad() || key.requires_grad() || value.requires_grad() ||
-       (attn_mask.has_value() && attn_mask.value().requires_grad()));
-}
-
-bool check_grad(sdp::sdp_params const& params, bool debug) {
-  if (!input_require_grad(
-          params.query, params.key, params.value, params.attn_mask))
-    return true;
-
-  auto q_num_heads = params.query.sym_size(-3);
-  auto k_num_heads = params.key.sym_size(-3);
-  auto v_num_heads = params.value.sym_size(-3);
-  bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
-  if (debug && is_gqa)
-    TORCH_WARN(
-        "scale_dot_product_attention with gqa is not supported for gradient computation on xpu.");
-
-  bool attn_mask_needs_grad =
-      params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
-  if (debug && attn_mask_needs_grad) {
-    TORCH_WARN(
-        "scale_dot_product_attention on xpu is not supported when attn_mask.requires_grad() == True.");
+bool check_no_grad(sdp::sdp_params const& params, bool debug) {
+  const bool any_inputs_require_grad = params.query.requires_grad() ||
+      params.key.requires_grad() || params.value.requires_grad();
+  const bool gradmode_enabled = at::GradMode::is_enabled();
+  if (debug && any_inputs_require_grad && gradmode_enabled) {
+    TORCH_WARN("Backward or grad to be supported.");
   }
-
-  return !is_gqa && !attn_mask_needs_grad;
+  return !any_inputs_require_grad || !gradmode_enabled;
 }
 
 bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
@@ -88,7 +65,7 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
       sdp::check_nonzero_sequence_lengths_dense,
       sdp::check_last_dim_stride_equals_1_dense,
       check_head_dim_size_xpu,
-      check_grad);
+      check_no_grad);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;
     }
   }
@@ -248,11 +225,10 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_logsumexp) {
+    std::optional<double> scale) {
   TORCH_INTERNAL_ASSERT(
       query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
-      "scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
+      "scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
   TORCH_INTERNAL_ASSERT(
       (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
           (key.size(2) == value.size(2)),
@@ -269,9 +245,6 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   TORCH_INTERNAL_ASSERT(
       !(attn_bias.has_value() && is_causal),
       "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
-  TORCH_INTERNAL_ASSERT(
-      !(attn_bias.has_value() && attn_bias.value().requires_grad()),
-      "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True");
 
   const int64_t batch_size = query.size(0);
   const int64_t num_head_q = query.size(1);
@@ -281,14 +254,11 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   const int64_t seq_len_q = query.size(2);
   const int64_t seq_len_kv = key.size(2);
 
-  at::Tensor attention;
-  std::vector<int64_t> attention_shape = {
+  at::Tensor output;
+  std::vector<int64_t> output_shape = {
       batch_size, num_head_q, seq_len_q, head_dim_v};
-  alloc_with_matching_layout(query, attention, attention_shape);
-
-  auto opts = query.options();
-  at::Tensor logsumexp =
-      at::empty({batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat));
+  alloc_with_matching_layout(query, output, output_shape);
+  at::Tensor logsumexp, debug_attn_mask; // not supported
 
   at::native::onednn::sdpa(
       batch_size,
@@ -304,15 +274,15 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
       attn_bias,
       is_causal,
       scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)),
-      attention,
-      compute_logsumexp,
+      output,
+      false,
       logsumexp);
 
   // rng not used
   auto philox_seed = at::empty({}, at::dtype(at::kLong));
   auto philox_offset = at::empty({}, at::dtype(at::kLong));
   return std::make_tuple(
-      attention,
+      output,
       logsumexp,
       /* cum_seq_q */ at::Tensor(),
       /* cum_seq_k */ at::Tensor(),
@@ -320,106 +290,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
       seq_len_kv,
       philox_seed,
       philox_offset,
-      /*debug_attn_mask */ at::Tensor());
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-_scaled_dot_product_fused_attention_overrideable_backward_xpu(
-    const at::Tensor& grad_out,
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const at::Tensor& attn_bias,
-    std::array<bool, 4> grad_input_mask,
-    const at::Tensor& out,
-    const at::Tensor& logsumexp,
-    const at::Tensor& cum_seq_q,
-    const at::Tensor& cum_seq_k,
-    int64_t max_q,
-    int64_t max_k,
-    double dropout_p,
-    bool is_causal,
-    const at::Tensor& philox_seed,
-    const at::Tensor& philox_offset,
-    std::optional<double> scale) {
-  TORCH_INTERNAL_ASSERT(
-      grad_out.dim() == 4 && out.dim() == 4 &&
-          grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
-          grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {B, H, T, K}");
-  TORCH_INTERNAL_ASSERT(
-      query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
-  TORCH_INTERNAL_ASSERT(
-      (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
-          (key.size(2) == value.size(2)),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head");
-  TORCH_INTERNAL_ASSERT(
-      query.size(0) == grad_out.size(0) && query.size(1) == grad_out.size(1) &&
-          query.size(2) == grad_out.size(2),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out");
-  TORCH_INTERNAL_ASSERT(
-      query.size(3) == key.size(3),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q/K should have the same head_dim");
-  TORCH_INTERNAL_ASSERT(
-      value.size(3) == grad_out.size(3),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: V should have the same head_dim as grad_out");
-  TORCH_INTERNAL_ASSERT(
-      query.size(1) == key.size(1),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: number of heads in K/V must equal to number of heads in Q");
-  TORCH_INTERNAL_ASSERT(
-      dropout_p == 0.0,
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0");
-  TORCH_INTERNAL_ASSERT(
-      logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) &&
-          logsumexp.size(1) == query.size(1) &&
-          logsumexp.size(2) == query.size(2) &&
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {B, H, T}");
-
-  std::optional<at::Tensor> attn_bias_opt;
-  if (attn_bias.defined()) {
-    attn_bias_opt = attn_bias;
-  }
-
-  const int64_t batch_size = query.size(0);
-  const int64_t num_head_q = query.size(1);
-  const int64_t num_head_kv = key.size(1);
-  const int64_t seq_len_q = query.size(2);
-  const int64_t seq_len_kv = key.size(2);
-  const int64_t head_dim_qk = query.size(3);
-  const int64_t head_dim_v = value.size(3);
-
-  auto grad_q = at::empty_like(query);
-  auto grad_k = at::empty_like(key);
-  auto grad_v = at::empty_like(value);
-  auto grad_attn_bias = attn_bias_opt.has_value()
-      ? at::empty_like(attn_bias_opt.value())
-      : at::Tensor();
-  at::native::onednn::sdpa_backward(
-      batch_size,
-      num_head_q,
-      num_head_kv,
-      seq_len_q,
-      seq_len_kv,
-      head_dim_qk,
-      head_dim_v,
-      grad_out,
-      query,
-      key,
-      value,
-      out,
-      logsumexp,
-      attn_bias_opt,
-      is_causal,
-      scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3))),
-      grad_q,
-      grad_k,
-      grad_v);
-  return std::make_tuple(
-      std::move(grad_q),
-      std::move(grad_k),
-      std::move(grad_v),
-      std::move(grad_attn_bias));
+      debug_attn_mask);
 }
 
 REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d76a7385907..0cc734350b9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -15095,7 +15095,7 @@
     CPU: _scaled_dot_product_flash_attention_cpu
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None, bool compute_log_sumexp=True) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
     XPU: _scaled_dot_product_fused_attention_overrideable_xpu
@@ -15119,7 +15119,6 @@
   variants: function
   dispatch:
     CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward
-    XPU: _scaled_dot_product_fused_attention_overrideable_backward_xpu
 
 - func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   dispatch:
diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index f206aca87dc..7aad4309924 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -768,11 +768,8 @@ Tensor scaled_dot_product_attention(
       return std::get<0>(out_and_lse);
     }
     case SDPBackend::overrideable: {
-      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
-      compute_logsumexp = compute_logsumexp ||
-          (at::GradMode::is_enabled() && attn_mask.has_value() && attn_mask.value().requires_grad());
       auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable(
-          query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale, compute_logsumexp);
+          query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale);
       return std::get<0>(out_lse_softmax);
     }
     case SDPBackend::math: {
@@ -1018,8 +1015,7 @@ _scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_logsumexp) {
+    std::optional<double> scale) {
   TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable not implemented. This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function ");
 }
 
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
index 2fc15c70666..1ba82564cc3 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
@@ -58,8 +58,7 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_log_sumexp) {
+    std::optional<double> scale) {
   return at::native::openreg::_scaled_dot_product_fused_attention_overrideable(
       query,
       key,
@@ -68,8 +67,7 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
       dropout_p,
       is_causal,
       return_debug_mask,
-      scale,
-      compute_log_sumexp);
+      scale);
 }
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
index 9aa341bb4bc..129ad621cf8 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
@@ -47,8 +47,7 @@ _scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_log_sumexp) {
+    std::optional<double> scale) {
   const int64_t batch_size = query.size(0);
   const int64_t num_heads = query.size(1);
   const int64_t head_dim_v = value.size(3);
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h
index b4d24312a45..f002949a103 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h
@@ -39,8 +39,7 @@ _scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_log_sumexp);
+    std::optional<double> scale);
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 _scaled_dot_product_fused_attention_overrideable_backward(
     const at::Tensor& grad_out,
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 25ef1cd9594..ec6ae548996 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -4387,7 +4387,7 @@ class TestSDPAXpuOnly(NNTestCase):
         self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
 
-    @parametrize("fused_kernel", [SDPBackend.OVERRIDEABLE])
+    @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
     @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
     @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
         (2, 5, 9216, 9216, 64),
@@ -4426,7 +4426,7 @@ class TestSDPAXpuOnly(NNTestCase):
         tol = Tolerances(5e-2, 5e-2)
         if dtype is torch.float16:
             tol = Tolerances(1e-2, 1e-2)
-        mask_shape = [batch_size, 1, q_size, kv_size]
+        mask_shape = [batch_size, 1, 1, kv_size]
         make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
         q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
         kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
@@ -4435,6 +4435,14 @@ class TestSDPAXpuOnly(NNTestCase):
         v = make_tensor(kv_shape)
         q2, k2, v2 = q.clone(), k.clone(), v.clone()
 
+        if train:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+            q2.requires_grad_(True)
+            k2.requires_grad_(True)
+            v2.requires_grad_(True)
+
         # (B, nh, T, hs)
         q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
         k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
@@ -4454,43 +4462,17 @@ class TestSDPAXpuOnly(NNTestCase):
         v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
         attn_mask2 = attn_mask.float() if attn_mask is not None else None
 
-        if train:
-            q = q.detach().clone().requires_grad_(True)
-            k = k.detach().clone().requires_grad_(True)
-            v = v.detach().clone().requires_grad_(True)
-            q2 = q2.detach().clone().requires_grad_(True)
-            k2 = k2.detach().clone().requires_grad_(True)
-            v2 = v2.detach().clone().requires_grad_(True)
+        if fused_kernel == SDPBackend.MATH:
+            actual = torch.ops.aten._scaled_dot_product_attention_math(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+        elif fused_kernel == SDPBackend.OVERRIDEABLE:
+            actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+                q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
 
-        with sdpa_kernel(backends=[fused_kernel]):
-            actual = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]
 
-        with sdpa_kernel(backends=[SDPBackend.MATH]):
-            math_ref = F.scaled_dot_product_attention(
-                q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)
-
-        if dtype in [torch.float16, torch.bfloat16]:
-            math_ref = math_ref.to(dtype)
-
-        self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
-
-        if train:
-            loss = torch.mean(actual)
-            loss_ref = torch.mean(math_ref)
-            loss.backward()
-            loss_ref.backward()
-
-            grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
-            grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
-            if dtype in [torch.float16, torch.bfloat16]:
-                grad_q_ref = grad_q_ref.to(dtype)
-                grad_k_ref = grad_k_ref.to(dtype)
-                grad_v_ref = grad_v_ref.to(dtype)
-
-            self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
-            self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
-            self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
+        self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)
 
 
 class TestAttnBias(NNTestCase):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 4cd02ed35e9..88e0a316f9d 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2907,7 +2907,7 @@
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)
 
-- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None, bool compute_log_sumexp=True) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 2305ffbaca5..91080bf5a8b 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5754,7 +5754,6 @@ def meta__scaled_dot_product_fused_attention_overrideable(
    is_causal: bool = False,
    return_debug_mask: bool = False,
    scale: Optional[float] = None,
-    compute_log_sumexp: bool = True,
 ):
     B = query.size(0)
     H_Q = query.size(1)
@@ -5788,36 +5787,6 @@ def meta__scaled_dot_product_fused_attention_overrideable(
     )
 
 
-@register_meta([aten._scaled_dot_product_fused_attention_overrideable_backward])
-def meta__scaled_dot_product_fused_attention_overrideable_backward(
-    grad_out: Tensor,
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    attn_bias: Tensor,
-    grad_input_mask: list[bool],
-    out: Tensor,
-    logsumexp: Tensor,
-    cum_seq_q: Tensor,
-    cum_seq_k: Tensor,
-    max_q: int,
-    max_k: int,
-    dropout_p: float,
-    is_causal: bool,
-    philox_seed: Tensor,
-    philox_offset: Tensor,
-    scale: Optional[float] = None,
-):
-    grad_q = torch.empty_like(query)
-    grad_k = torch.empty_like(key)
-    grad_v = torch.empty_like(value)
-
-    grad_attn_bias = None
-    if attn_bias is not None:
-        grad_attn_bias = torch.empty_like(attn_bias)
-    return grad_q, grad_k, grad_v, grad_attn_bias
-
-
 @register_meta(
     [
         aten._scaled_dot_product_flash_attention_backward,
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
index f649f9781ec..aced2b2f539 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
@@ -36,7 +36,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
index 3688e85c316..c41487ae6dd 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -42,7 +42,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_a
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_grouped_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* offs, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
index 24dc9a32d24..e075956e14d 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
@@ -25,7 +25,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
index e24728fd5c1..39f0dec8616 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
@@ -15,7 +15,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHand
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);
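
Reviewer note (not part of the patch): a minimal Python sketch of the behavior this diff leaves in place. check_no_grad only admits the XPU overrideable fused kernel when grad mode is off or no input requires grad, so inference-style calls can still select it, while training-style calls are expected to fall back to another enabled backend such as math. The device probe, shapes, and dtype below are illustrative assumptions, not values taken from the patch.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative 4-D {B, H, T, K} inputs; the XPU kernel only accepts this layout.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
q = torch.randn(2, 8, 128, 64, device=device)
k = torch.randn(2, 8, 128, 64, device=device)
v = torch.randn(2, 8, 128, 64, device=device)

# Inference: grad mode is off, so check_no_grad passes and the overrideable
# fused path stays eligible on XPU.
with torch.no_grad(), sdpa_kernel([SDPBackend.OVERRIDEABLE, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

# Training-style call: an input requires grad under grad mode, so the XPU
# overrideable kernel is rejected and selection should fall back to MATH.
q.requires_grad_(True)
with sdpa_kernel([SDPBackend.OVERRIDEABLE, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)
out.sum().backward()

Calling the op directly now also matches the trimmed schema, e.g. torch.ops.aten._scaled_dot_product_fused_attention_overrideable(q, k, v, attn_bias=None, dropout_p=0.0, is_causal=False) with no trailing compute_log_sumexp argument, which is how the updated test exercises it.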